From 8699095c00be0cc93e82aacfb69ceaf3af9a200e Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Tue, 16 Aug 2022 15:18:54 -0500 Subject: [PATCH 1/2] Simpler approach to slice_head/slice_tail --- R/slice.R | 24 ++++++------------------ tests/testthat/test-slice.r | 2 +- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/R/slice.R b/R/slice.R index 4f8f4549df..cd1e38ac15 100644 --- a/R/slice.R +++ b/R/slice.R @@ -139,16 +139,10 @@ slice_head <- function(.data, ..., n, prop) { #' @export slice_head.data.frame <- function(.data, ..., n, prop) { size <- get_slice_size(n = n, prop = prop) - idx <- function(n) { - to <- size(n) - if (to > n) { - to <- n - } - seq2(1, to) - } - dplyr_local_error_call() - slice(.data, idx(dplyr::n())) + group_idx <- group_rows(.data) + slice_idx <- unlist(lapply(group_idx, function(x) head(x, size(length(x))))) + dplyr_row_slice(.data, slice_idx) } #' @export @@ -162,16 +156,10 @@ slice_tail <- function(.data, ..., n, prop) { #' @export slice_tail.data.frame <- function(.data, ..., n, prop) { size <- get_slice_size(n = n, prop = prop) - idx <- function(n) { - from <- n - size(n) + 1 - if (from < 1L) { - from <- 1L - } - seq2(from, n) - } - dplyr_local_error_call() - slice(.data, idx(dplyr::n())) + group_idx <- group_rows(.data) + slice_idx <- unlist(lapply(group_idx, function(x) tail(x, size(length(x))))) + dplyr_row_slice(.data, slice_idx) } #' @export diff --git a/tests/testthat/test-slice.r b/tests/testthat/test-slice.r index 6da205ff3c..7e6cb899d3 100644 --- a/tests/testthat/test-slice.r +++ b/tests/testthat/test-slice.r @@ -186,7 +186,7 @@ test_that("slice_*() checks that `n=` is explicitly named and ... is empty", { test_that("slice_helpers do call slice() and benefit from dispatch (#6084)", { local_methods( - slice.noisy = function(.data, ..., .preserve = FALSE) { + dplyr_row_slice.noisy = function(.data, ..., .preserve = FALSE) { warning("noisy") NextMethod() } From f7b80ee97cb5f13b7b265327a1f4d4c8d1dd3c56 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Tue, 16 Aug 2022 15:44:03 -0500 Subject: [PATCH 2/2] Alternative implementation for slice_sample() --- R/slice.R | 40 +++++++++++++++++++--------------- tests/testthat/_snaps/slice.md | 7 +++--- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/R/slice.R b/R/slice.R index cd1e38ac15..f41609d3f7 100644 --- a/R/slice.R +++ b/R/slice.R @@ -141,8 +141,8 @@ slice_head.data.frame <- function(.data, ..., n, prop) { size <- get_slice_size(n = n, prop = prop) group_idx <- group_rows(.data) - slice_idx <- unlist(lapply(group_idx, function(x) head(x, size(length(x))))) - dplyr_row_slice(.data, slice_idx) + slice_idx <- lapply(group_idx, function(x) head(x, size(length(x)))) + dplyr_row_slice(.data, unlist(slice_idx)) } #' @export @@ -158,8 +158,8 @@ slice_tail.data.frame <- function(.data, ..., n, prop) { size <- get_slice_size(n = n, prop = prop) group_idx <- group_rows(.data) - slice_idx <- unlist(lapply(group_idx, function(x) tail(x, size(length(x))))) - dplyr_row_slice(.data, slice_idx) + slice_idx <- lapply(group_idx, function(x) tail(x, size(length(x)))) + dplyr_row_slice(.data, unlist(slice_idx)) } #' @export @@ -253,16 +253,20 @@ slice_sample <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) slice_sample.data.frame <- function(.data, ..., n, prop, weight_by = NULL, replace = FALSE) { size <- get_slice_size(n = n, prop = prop, allow_negative = FALSE) - dplyr_local_error_call() - slice(.data, local({ - weight_by <- {{ weight_by }} + if (!missing(weight_by)) { + weight_by <- transmute(.data, ..weight_by = {{ weight_by }})[[1]] + } - n <- dplyr::n() - if (!is.null(weight_by)) { - weight_by <- vec_assert(weight_by, size = n, arg = "weight_by") - } - sample_int(n, size(n), replace = replace, wt = weight_by) - })) + group_idx <- group_rows(.data) + slice_idx <- vector("list", length(group_idx)) + for (i in seq_along(group_idx)) { + idx <- group_idx[[i]] + n <- size(length(idx)) + + slice_idx[[i]] <- sample_int(idx, n, replace = replace, wt = weight_by[idx]) + } + + dplyr_row_slice(.data, unlist(slice_idx)) } # helpers ----------------------------------------------------------------- @@ -454,15 +458,15 @@ get_slice_size <- function(n, prop, allow_negative = TRUE, error_call = caller_e } } -sample_int <- function(n, size, replace = FALSE, wt = NULL, call = caller_env()) { - if (!replace && n < size) { - size <- n +sample_int <- function(x, size, replace = FALSE, wt = NULL, call = caller_env()) { + if (!replace && length(x) < size) { + size <- length(x) } if (size == 0L) { - integer(0) + x[integer(0)] } else { - sample.int(n, size, prob = wt, replace = replace) + x[sample.int(length(x), size, prob = wt, replace = replace)] } } diff --git a/tests/testthat/_snaps/slice.md b/tests/testthat/_snaps/slice.md index 5e246eba73..3cbd141b2c 100644 --- a/tests/testthat/_snaps/slice.md +++ b/tests/testthat/_snaps/slice.md @@ -234,10 +234,9 @@ Code slice_sample(df, n = 2, weight_by = 1:6) Condition - Error in `slice_sample()`: - ! Problem while computing indices. - Caused by error: - ! `weight_by` must have size 10, not size 6. + Error in `transmute()`: + ! Problem while computing `..weight_by = 1:6`. + x `..weight_by` must be size 10 or 1, not 6. # `slice_sample()` validates `replace`