Tidymodels: how to create a custom metric for use with tune_grid() that can use a grouped data.frame/tibble?

What I'd like to do

I am trying to build a model in tidymodels that predicts the efficacy of drugs on cell lines (like bacteria). The model will rank drugs by efficacy for a given cell line, so I want to use Spearman's correlation (ρ) as the metric. In the example data set below, each cell line (column Sample) is represented by a letter (Q, R, S, ..., Z), and each cell line was treated with 50 drugs.

When I split the data for cross-validation, the training/test splits for each fold will each contain more than one cell line (e.g. Q and R in the test split for fold 1). When calculating the metric, however, I want to compute ρ for each cell line individually and then average across all the cell lines in the test split, rather than compute it over all the observations in aggregate. For example, if the test split for fold 1 consists of Q and R, I want to calculate ρ for the 50 drugs tested against Q, a separate ρ for the 50 drugs tested against R, and then average the two; that average should be the metric reported for fold 1.
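To make the target quantity concrete, here is a minimal base-R sketch of the per-cell-line averaging described above. It is independent of tidymodels; the toy values and the names test_split, rho_by_sample, mean_rho and pooled_rho are made up for illustration and are not part of the real data.

```r
# Toy test split with two cell lines, Q and R (values invented for illustration).
test_split = data.frame(
  Sample = rep(c("Q", "R"), each = 5),
  truth  = c(1, 2, 3, 4, 5,  1, 2, 3, 4, 5),
  pred   = c(1.1, 2.3, 2.9, 4.2, 5.5,  5, 4, 3, 2, 1)
)

# Spearman's rho computed separately within each cell line.
rho_by_sample = sapply(
  split(test_split, test_split$Sample),
  function(d) cor(d$truth, d$pred, method = "spearman")
)

# The fold-level metric is the plain average of the per-cell-line values.
mean_rho = mean(rho_by_sample)

# For contrast: rho over all observations pooled together, which is
# the aggregate behaviour this question is trying to avoid.
pooled_rho = cor(test_split$truth, test_split$pred, method = "spearman")
```

In this toy split the predictions for Q are perfectly ordered (ρ = 1) and those for R are perfectly reversed (ρ = -1), so the averaged metric is 0; the pooled value is generally a different number.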

What I've tried

I was thinking that I'd have to calculate the metric on a tibble/data.frame grouped by the Sample column, but I can't figure out how to pass that variable into tune_grid(). I don't think I can include the variable in add_formula() when creating the workflow object, since I don't want it as a predictor variable. I just discovered tidymodels yesterday, so maybe there's a straightforward solution I'm unaware of, but I haven't been able to find anything on Google so far. The code below is what I've tried, but obviously it doesn't work. Thank you in advance for any advice you can give.

Error

i Resample1: preprocessor 1/1
✓ Resample1: preprocessor 1/1
i Resample1: preprocessor 1/1, model 1/20
✓ Resample1: preprocessor 1/1, model 1/20
i Resample1: preprocessor 1/1, model 1/20 (predictions)
x Resample1: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample2: preprocessor 1/1
✓ Resample2: preprocessor 1/1
i Resample2: preprocessor 1/1, model 1/20
✓ Resample2: preprocessor 1/1, model 1/20
i Resample2: preprocessor 1/1, model 1/20 (predictions)
x Resample2: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample3: preprocessor 1/1
✓ Resample3: preprocessor 1/1
i Resample3: preprocessor 1/1, model 1/20
✓ Resample3: preprocessor 1/1, model 1/20
i Resample3: preprocessor 1/1, model 1/20 (predictions)
x Resample3: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample4: preprocessor 1/1
✓ Resample4: preprocessor 1/1
i Resample4: preprocessor 1/1, model 1/20
✓ Resample4: preprocessor 1/1, model 1/20
i Resample4: preprocessor 1/1, model 1/20 (predictions)
x Resample4: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample5: preprocessor 1/1
✓ Resample5: preprocessor 1/1
i Resample5: preprocessor 1/1, model 1/20
✓ Resample5: preprocessor 1/1, model 1/20
i Resample5: preprocessor 1/1, model 1/20 (predictions)
x Resample5: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
Warning message:
All models failed. See the `.notes` column. 

Upon running glmnet_tuning_results:

Warning message:
This tuning result has notes. Example notes on model fitting include:
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)

Code

Example data set

library(tidymodels)

# Ten cell lines (Q through Z), each treated with 50 drugs.
data = tibble(
  Sample = rep(LETTERS[17:26], each = 50),
  TargetVariable = rnorm(500, mean = 0, sd = 1),
  PredictorVariable1 = rnorm(500, mean = 5, sd = 1),
  PredictorVariable2 = rpois(500, lambda = 5)
)

Model

# Splitting for cross-validation.
set.seed(1026)
folds = group_vfold_cv(data, group = Sample, v = 5)

# Model specification.
glmnet_model = linear_reg(
  mode    = "regression", 
  penalty = tune(), 
  mixture = tune()
) %>%
  set_engine("glmnet")

# Workflow.
glmnet_wf = workflow() %>%
  add_model(glmnet_model) %>% 
  add_formula(TargetVariable ~ . - Sample)

# Grid specification.
glmnet_params = parameters(penalty(), mixture())
set.seed(1026)
glmnet_grid = grid_max_entropy(glmnet_params, size = 20)

# Hyperparameter tuning.
glmnet_tuning_results = tune_grid(
  glmnet_wf,
  resamples = folds,
  grid      = glmnet_grid,
  metrics   = metric_set(spearman_cor),
  control   = control_grid(verbose = TRUE)
)

glmnet_tuning_results %>% show_best(n = 10)

Custom metric

# Vector version.
spearman_cor_vec = function(truth, estimate, na_rm = TRUE) {
  
  spearman_cor_impl = function(truth, estimate) {
    cor(truth, estimate, method = "spearman")
  }
  
  metric_vec_template(
    metric_impl = spearman_cor_impl,
    truth = truth, 
    estimate = estimate,
    na_rm = na_rm,
    cls = "numeric"
  )
}

# Data frame version.
spearman_cor = function(data) {
  UseMethod("spearman_cor")
}

spearman_cor = new_numeric_metric(spearman_cor, direction = "maximize")

spearman_cor.data.frame = function(data, truth, estimate, na_rm = TRUE) {
  
  data_grouped = data %>%
    group_by(Sample)
  
  metric_summarizer(
    metric_nm = "spearman_cor",
    metric_fn = spearman_cor_vec,
    data = data_grouped,
    truth = !! enquo(truth),
    estimate = !! enquo(estimate), 
    na_rm = na_rm
  )
  
}
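
A note on the "unused arguments" message: one plausible reading, assuming the metric is called with named truth, estimate and na_rm arguments as the error text suggests, is that the generic above only accepts data, so the extra arguments are rejected before dispatch reaches the data.frame method. The base-R sketch below reproduces that behaviour with hypothetical toy generics (narrow_metric and wide_metric are not yardstick functions); it does not address how to get the Sample column to the metric.

```r
# Hypothetical toy generic: its signature has no `...`, so any extra
# argument is rejected when the generic itself is called.
narrow_metric = function(data) {
  UseMethod("narrow_metric")
}
narrow_metric.data.frame = function(data, truth, estimate) {
  cor(data[[truth]], data[[estimate]], method = "spearman")
}

# Same idea, but the generic accepts `...`, so extra arguments are
# passed along to the method chosen by UseMethod().
wide_metric = function(data, ...) {
  UseMethod("wide_metric")
}
wide_metric.data.frame = function(data, truth, estimate, ...) {
  cor(data[[truth]], data[[estimate]], method = "spearman")
}

toy = data.frame(obs = c(1, 2, 3, 4), pred = c(2, 1, 4, 3))

# Capture the error message instead of stopping the script.
narrow_result = tryCatch(
  narrow_metric(toy, truth = "obs", estimate = "pred"),
  error = function(e) conditionMessage(e)
)

# Dispatches normally and returns Spearman's rho for the toy data.
wide_result = wide_metric(toy, truth = "obs", estimate = "pred")
```

Here narrow_result holds an "unused arguments" message, while wide_result is a single correlation value.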

Session info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 3.6.3 (2020-02-29)
#>  os       macOS Catalina 10.15.7      
#>  system   x86_64, darwin15.6.0        
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       America/Chicago             
#>  date     2021-08-25                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version date       lib source        
#>  backports     1.1.6   2020-04-05 [1] CRAN (R 3.6.2)
#>  cli           3.0.1   2021-07-17 [1] CRAN (R 3.6.2)
#>  crayon        1.3.4   2017-09-16 [1] CRAN (R 3.6.0)
#>  digest        0.6.25  2020-02-23 [1] CRAN (R 3.6.0)
#>  ellipsis      0.3.2   2021-04-29 [1] CRAN (R 3.6.2)
#>  evaluate      0.14    2019-05-28 [1] CRAN (R 3.6.0)
#>  fansi         0.4.1   2020-01-08 [1] CRAN (R 3.6.0)
#>  fs            1.3.1   2019-05-06 [1] CRAN (R 3.6.0)
#>  glue          1.4.0   2020-04-03 [1] CRAN (R 3.6.2)
#>  highr         0.8     2019-03-20 [1] CRAN (R 3.6.0)
#>  htmltools     0.5.1.1 2021-01-22 [1] CRAN (R 3.6.2)
#>  knitr         1.27    2020-01-16 [1] CRAN (R 3.6.0)
#>  lifecycle     1.0.0   2021-02-15 [1] CRAN (R 3.6.2)
#>  magrittr      2.0.1   2020-11-17 [1] CRAN (R 3.6.2)
#>  pillar        1.6.2   2021-07-29 [1] CRAN (R 3.6.2)
#>  pkgconfig     2.0.3   2019-09-22 [1] CRAN (R 3.6.0)
#>  purrr         0.3.4   2020-04-17 [1] CRAN (R 3.6.2)
#>  Rcpp          1.0.4.6 2020-04-09 [1] CRAN (R 3.6.1)
#>  reprex        2.0.1   2021-08-05 [1] CRAN (R 3.6.2)
#>  rlang         0.4.10  2020-12-30 [1] CRAN (R 3.6.2)
#>  rmarkdown     2.1     2020-01-20 [1] CRAN (R 3.6.0)
#>  rstudioapi    0.13    2020-11-12 [1] CRAN (R 3.6.2)
#>  sessioninfo   1.1.1   2018-11-05 [1] CRAN (R 3.6.0)
#>  stringi       1.4.5   2020-01-11 [1] CRAN (R 3.6.0)
#>  stringr       1.4.0   2019-02-10 [1] CRAN (R 3.6.0)
#>  styler        1.5.1   2021-07-13 [1] CRAN (R 3.6.2)
#>  tibble        3.1.3   2021-07-23 [1] CRAN (R 3.6.2)
#>  utf8          1.1.4   2018-05-24 [1] CRAN (R 3.6.0)
#>  vctrs         0.3.8   2021-04-29 [1] CRAN (R 3.6.2)
#>  withr         2.4.2   2021-04-18 [1] CRAN (R 3.6.2)
#>  xfun          0.12    2020-01-13 [1] CRAN (R 3.6.0)
#>  yaml          2.2.0   2018-07-25 [1] CRAN (R 3.6.0)
#> 
#> [1] /Library/Frameworks/R.framework/Versions/3.6/Resources/library
