I'd like to write a function that knows which group it is running on when called from dplyr::summarise on a grouped tibble. The best way I can think of to do this is to pass in the grouping variable as a default argument to the function, but I've been unable to find a way to get a default argument to be evaluated in the data mask. I've found one workaround that modifies the call, but I'm hoping there's a better way.
I know this is a weird thing to want to do, so let me try to explain why I think I need to do it. This is for the srvyr package, which wraps survey package functions in dplyr syntax. Up to now, srvyr's methods for summarise() haven't called the dplyr methods but have reimplemented them instead. However, this has been a major source of bugs, and I believe it also prevents dplyr::across() from working. I'm hoping to rewrite srvyr so that it uses the dplyr methods, but for grouped surveys I'll need some way to subset the survey weights per group. The user does not specify these weights in each function call, though; they are defined as part of the survey object.
Here's some code that I think shows what I'm trying to do better than I can write down in words.
suppressPackageStartupMessages({
  library(dplyr)
  library(rlang)
})
options(dplyr.summarise.inform = FALSE)
df <- tibble(
  val = 1:4,
  grp = c("a", "a", "b", "b")
)
# this works...
custom_sum <- function(x, group) {
  sum(case_when(
    group == "a" ~ x * 1,
    group == "b" ~ x * 2
  ))
}
df %>%
  group_by(grp) %>%
  summarize(wt_val = custom_sum(val, grp))
#> # A tibble: 2 x 2
#>   grp   wt_val
#>   <chr>  <dbl>
#> 1 a          3
#> 2 b         14
# But I want the 2nd argument to get a default value from the data mask.
# I've tried various default values for group:
# - .data pronoun (shown below)
# - quo, sym, expr with and without unquoting
custom_sum2 <- function(x, group = .data[["grp"]]) {
  sum(case_when(
    group == "a" ~ x * 1,
    group == "b" ~ x * 2
  ))
}
df %>%
  group_by(grp) %>%
  summarize(wt_val = custom_sum2(val))
#> Error: Problem with `summarise()` input `wt_val`.
#> x `group == "a" ~ x * 1`, `group == "b" ~ x * 2` must be length 0 or one, not 2.
#> Backtrace:
#> █
#> 1. ├─df %>% group_by(grp) %>% summarize(wt_val = custom_sum2(val))
#> 2. │ ├─base::withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
#> 3. │ └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#> 4. │ └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#> 5. │ └─`_fseq`(`_lhs`)
#> 6. │ └─magrittr::freduce(value, `_function_list`)
#> 7. │ ├─base::withVisible(function_list[[k]](value))
#> 8. │ └─function_list[[k]](value)
#> 9. │ ├─dplyr::summarize(., wt_val = custom_sum2(val))
#> 10. │ └─dplyr:::summarise.grouped_df(., wt_val = custom_sum2(val))
#> 11. │ └─dplyr:::summarise_cols(.data, ...)
#> 12. │ ├─base::tryCatch(...)
#> 13. │ │ └─base:::tryCatchList(expr, classes, parentenv, handlers)
#> 14. │ │ └─base:::tryCatchOne(expr, names, parentenv, handlers[[1L]])
#> 15. │ │ └─base:::doTryCatch(return(expr), name, parentenv, handler)
#> 16. │ └─mask$eval_all_summarise(quo)
#> 17. └─global::custom_sum2(val)
#> 18. └─dplyr::case_when(group == "a" ~ x * 1, group == "b" ~ x * 2)
#> 19. └─dplyr:::validate_case_when_length(query, value, fs)
#> 20. └─dplyr:::bad_calls(...)
#> 21. └─dplyr:::glubort(fmt_calls(calls), ..., .envir = .envir)
#> ℹ Input `wt_val` is `custom_sum2(val)`.
#> ℹ The error occured in group 1: grp = "a".
#> Backtrace:
#> █
#> 1. └─df %>% group_by(grp) %>% summarize(wt_val = custom_sum2(val))
#> 2. ├─base::withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
#> 3. └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#> 4. └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#> 5. └─`_fseq`(`_lhs`)
#> 6. └─magrittr::freduce(value, `_function_list`)
#> 7. ├─base::withVisible(function_list[[k]](value))
#> 8. └─function_list[[k]](value)
#> 9. ├─dplyr::summarize(., wt_val = custom_sum2(val))
#> 10. └─dplyr:::summarise.grouped_df(., wt_val = custom_sum2(val))
#> 11. └─dplyr:::summarise_cols(.data, ...)
#> <parent: error/rlang_error>
#> Backtrace:
#> █
#> 1. ├─mask$eval_all_summarise(quo)
#> 2. └─global::custom_sum2(val)
#> 3. └─dplyr::case_when(group == "a" ~ x * 1, group == "b" ~ x * 2)
#> 4. └─dplyr:::validate_case_when_length(query, value, fs)
#> 5. └─dplyr:::bad_calls(...)
#> 6. └─dplyr:::glubort(fmt_calls(calls), ..., .envir = .envir)
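My best guess at why the default fails, as a minimal sketch: R evaluates a default argument from inside the function's own environment rather than in the caller's, so the data mask never gets a chance to supply it. The toy function g() and the variable z below are hypothetical and only illustrate that scoping rule; rlang::eval_tidy() with a plain list stands in for what summarise() does with its mask.

```r
library(rlang)

z <- 1  # value visible from g()'s enclosing environment

# Hypothetical function whose default refers to a name that also exists in the mask
g <- function(x, y = z) x + y

mask <- list(z = 100)

# The default `y = z` is looked up from inside g(), so the mask's z is ignored
from_default <- eval_tidy(quote(g(1)), data = mask)

# Writing z in the call makes it part of the masked expression
from_explicit <- eval_tidy(quote(g(1, z)), data = mask)

from_default   # 2
from_explicit  # 101
```

This is consistent with custom_sum(val, grp) working while custom_sum2(val) does not: only arguments written out in the call are evaluated in the mask.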
# I can modify the call, which kind of works for my purpose,
# but isn't very robust and feels hacky
summarise.custom_tbl <- function(.data, ..., .groups = NULL) {
  # Capture the summarise expressions as quosures
  .dots <- rlang::quos(...)
  # Inject `group = grp` into each captured call
  new_dots <- lapply(.dots, function(dot) {
    new_expr <- call_modify(get_expr(dot), group = sym("grp"))
    set_expr(dot, new_expr)
  })
  # Drop the custom class so that dplyr's own method is dispatched
  class(.data) <- setdiff(class(.data), "custom_tbl")
  dplyr::summarise(.data, !!!new_dots, .groups = .groups)
}
df %>%
  group_by(grp) %>%
  structure(., class = c("custom_tbl", class(.))) %>%
  summarize(wt_val = custom_sum(val))
#> # A tibble: 2 x 2
#>   grp   wt_val
#>   <chr>  <dbl>
#> 1 a          3
#> 2 b         14
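For comparison, here is a minimal sketch of a variant that needs no extra argument at all, assuming dplyr >= 1.0.0, where cur_group() returns the current group's keys as a one-row tibble. custom_sum3() is a hypothetical name, and whether this approach fits srvyr's needs (for example, subsetting weights stored on the survey object) is untested here.

```r
library(dplyr)

df <- tibble(
  val = 1:4,
  grp = c("a", "a", "b", "b")
)

# Hypothetical variant of custom_sum(): asks dplyr which group is being processed
custom_sum3 <- function(x) {
  group <- cur_group()$grp          # key of the current group, e.g. "a"
  mult <- if (group == "a") 1 else 2
  sum(x * mult)
}

res <- df %>%
  group_by(grp) %>%
  summarize(wt_val = custom_sum3(val))
```

With the df above, res has wt_val equal to 3 for group "a" and 14 for group "b", the same values as the explicit-argument version. Note that cur_group() only works while a dplyr verb is evaluating, so custom_sum3() cannot be called on its own.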
Thanks for taking a look!