Is it possible for a function called from summarise to know which group it is being run on?

I'd like to write a function that knows which group it is running on when called from dplyr::summarise on a grouped tibble. The best way I can think of to do this is to pass the grouping variable in as a default argument, but I've been unable to find a way to get a default argument evaluated in the data mask. I've found one workaround that modifies the call, but I'm hoping there's a better way.

I know this is a weird thing to want to do, so let me try to explain why I think I need it. This is for the srvyr package, which wraps survey package functions in dplyr syntax. Up to now, srvyr's methods for summarise() haven't called the dplyr methods; instead, srvyr reimplements them. However, this has been a major source of bugs, and I believe it also prevents dplyr::across() from working. I'm hoping to rewrite things so that I'm using the dplyr methods, but for grouped surveys I'll need some way to subset the survey weights per group. The user does not specify these weights in each function call, though; they are defined as part of the survey object.

Here's some code that I think shows what I'm trying to do better than I can write down in words.

suppressPackageStartupMessages({
  library(dplyr)
  library(rlang)
})
options(dplyr.summarise.inform = FALSE)

df <- tibble(
  val = 1:4,
  grp = c("a", "a", "b", "b")
)

# this works...
custom_sum <- function(x, group) {
  sum(case_when(
    group == "a" ~ x * 1,
    group == "b" ~ x * 2
  ))
}

df %>%
  group_by(grp) %>% 
  summarize(wt_val = custom_sum(val, grp))
#> # A tibble: 2 x 2
#>   grp   wt_val
#>   <chr>  <dbl>
#> 1 a          3
#> 2 b         14


# But I want the 2nd argument to get a default value from the data mask.

# I've tried various default values for group:
# - .data pronoun (shown below)
# - quo, sym, expr with and without unquoting
custom_sum2 <- function(x, group = .data[["grp"]]) {
  sum(case_when(
    group == "a" ~ x * 1,
    group == "b" ~ x * 2
  ))
}

df %>%
  group_by(grp) %>% 
  summarize(wt_val = custom_sum2(val))
#> Error: Problem with `summarise()` input `wt_val`.
#> x `group == "a" ~ x * 1`, `group == "b" ~ x * 2` must be length 0 or one, not 2.
#> Backtrace:
#>      █
#>   1. ├─df %>% group_by(grp) %>% summarize(wt_val = custom_sum2(val))
#>   2. │ ├─base::withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
#>   3. │ └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#>   4. │   └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#>   5. │     └─`_fseq`(`_lhs`)
#>   6. │       └─magrittr::freduce(value, `_function_list`)
#>   7. │         ├─base::withVisible(function_list[[k]](value))
#>   8. │         └─function_list[[k]](value)
#>   9. │           ├─dplyr::summarize(., wt_val = custom_sum2(val))
#>  10. │           └─dplyr:::summarise.grouped_df(., wt_val = custom_sum2(val))
#>  11. │             └─dplyr:::summarise_cols(.data, ...)
#>  12. │               ├─base::tryCatch(...)
#>  13. │               │ └─base:::tryCatchList(expr, classes, parentenv, handlers)
#>  14. │               │   └─base:::tryCatchOne(expr, names, parentenv, handlers[[1L]])
#>  15. │               │     └─base:::doTryCatch(return(expr), name, parentenv, handler)
#>  16. │               └─mask$eval_all_summarise(quo)
#>  17. └─global::custom_sum2(val)
#>  18.   └─dplyr::case_when(group == "a" ~ x * 1, group == "b" ~ x * 2)
#>  19.     └─dplyr:::validate_case_when_length(query, value, fs)
#>  20.       └─dplyr:::bad_calls(...)
#>  21.         └─dplyr:::glubort(fmt_calls(calls), ..., .envir = .envir)
#> ℹ Input `wt_val` is `custom_sum2(val)`.
#> ℹ The error occured in group 1: grp = "a".
#> Backtrace:
#>      █
#>   1. └─df %>% group_by(grp) %>% summarize(wt_val = custom_sum2(val))
#>   2.   ├─base::withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
#>   3.   └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#>   4.     └─base::eval(quote(`_fseq`(`_lhs`)), env, env)
#>   5.       └─`_fseq`(`_lhs`)
#>   6.         └─magrittr::freduce(value, `_function_list`)
#>   7.           ├─base::withVisible(function_list[[k]](value))
#>   8.           └─function_list[[k]](value)
#>   9.             ├─dplyr::summarize(., wt_val = custom_sum2(val))
#>  10.             └─dplyr:::summarise.grouped_df(., wt_val = custom_sum2(val))
#>  11.               └─dplyr:::summarise_cols(.data, ...)
#> <parent: error/rlang_error>
#> Backtrace:
#>     █
#>  1. ├─mask$eval_all_summarise(quo)
#>  2. └─global::custom_sum2(val)
#>  3.   └─dplyr::case_when(group == "a" ~ x * 1, group == "b" ~ x * 2)
#>  4.     └─dplyr:::validate_case_when_length(query, value, fs)
#>  5.       └─dplyr:::bad_calls(...)
#>  6.         └─dplyr:::glubort(fmt_calls(calls), ..., .envir = .envir)


# I can modify the call, which kind of works for my purpose, 
# but isn't very robust and feels hacky
summarise.custom_tbl <- function(.data, ..., .groups = NULL) {
  .dots <- rlang::quos(...)
  
  new_dots <- lapply(.dots, function(dot) {
    new_expr <- call_modify(get_expr(dot), group = sym("grp"))
    set_expr(dot, new_expr)
  })
  
  class(.data) <- setdiff(class(.data), "custom_tbl")
  dplyr::summarise(.data, !!!new_dots, .groups = .groups)
}

df %>%
  group_by(grp) %>% 
  structure(., class = c("custom_tbl", class(.))) %>% 
  summarize(wt_val = custom_sum(val))
#> # A tibble: 2 x 2
#>   grp   wt_val
#>   <chr>  <dbl>
#> 1 a          3
#> 2 b         14

Thanks for taking a look!

Oh man, totally missed dplyr::cur_group_id(), which does exactly what I want.
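
For anyone landing here later, here is a minimal sketch of how that can look with the example data from the question. It assumes dplyr >= 1.0.0, where cur_group_id() and cur_group_rows() are available. The names custom_sum3, weighted_sum and weights are made up for illustration and are not part of srvyr or dplyr. The point is that a default argument is evaluated lazily in the function's own environment rather than in the data mask, so a column reference can't be found there, but the cur_*() helpers read the current group from dplyr's evaluation context and so work fine as defaults.

```r
library(dplyr)

df <- tibble(
  val = 1:4,
  grp = c("a", "a", "b", "b")
)

# Hypothetical helper: the default is only evaluated when `group_id` is first
# used, which happens while summarise() is processing a group, so
# cur_group_id() reports that group's integer id.
custom_sum3 <- function(x, group_id = cur_group_id()) {
  multiplier <- c(1, 2)[group_id]  # id 1 is grp "a", id 2 is grp "b"
  sum(x * multiplier)
}

res_id <- df %>%
  group_by(grp) %>%
  summarise(wt_val = custom_sum3(val))

# Hypothetical per-row weights stored outside the data frame, standing in for
# weights that live on a survey object.
weights <- c(1, 1, 2, 2)

# cur_group_rows() gives the row indices of the current group, which can be
# used to subset anything that is aligned with the rows of the original data.
weighted_sum <- function(x, w = weights[cur_group_rows()]) {
  sum(x * w)
}

res_rows <- df %>%
  group_by(grp) %>%
  summarise(wt_val = weighted_sum(val))
```

Both versions give wt_val of 3 for "a" and 14 for "b", matching the explicit custom_sum(val, grp) call in the question. If the function needs the group's key values rather than its id, cur_group() returns them as a one-row tibble.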
