Using a custom metric during tuning/racing.

I am using tidymodels to optimize and evaluate models, and I have written a custom metric used in virtual drug screening (enrichment factor at 1%). I've followed the vignettes/tutorials, and the metric works after model fitting when applied to predictions on the hold-out data, but I can't get it to work during tuning/racing. Because it's a class probability metric, I need to pass the class probability column in addition to the truth and estimate, but I can't get it to work in tune_grid() or workflow_map()/control_race(). Below is a reprex:

library(tidyverse)
library(tidymodels)
library(modeldata)

# metric ------------------------------------------------------------------

# get the correct event level
event_col <- function(xtab, event_level) {
  if (identical(event_level, "first")) {
    colnames(xtab)[[1]]
  } else {
    colnames(xtab)[[2]]
  }
}

# control what type of data can be used
finalize_estimator_internal.ef1 <- function(metric_dispatcher, x, estimator) {
  validate_estimator(estimator, estimator_override = "binary")
  
  if(!is.null(estimator)) {
    return(estimator)
  }
  
  lvls <- levels(x)
  
  if(length(lvls) > 2) {
    stop("A multiclass `truth` input was provided, but only `binary` is supported.")
  } 
  
  "binary"
}

# vector implementation
ef1_vec <- function(truth, estimate, estimate_val, estimator = NULL, event_level = "first", na_rm = TRUE, ...) {
  
  estimator <- finalize_estimator(truth, estimator, metric_class = "ef1")
  
  ef1_impl <- function(truth, estimate, estimate_val) {
    
    xtab <- table(estimate, truth)
    col <- event_col(xtab, event_level)
    
    N <- length(truth)
    
    A <- sum(truth == col)
    
    df <- bind_cols(truth, estimate, estimate_val) %>%
      rename(truth = 1, estimate = 2, estimate_val = 3) %>%
      arrange(-estimate_val) %>%
      slice_head(prop = 0.1) %>%
      group_by(truth) %>%
      tally()
    
    # there's probably a better way of getting the event name, but this works for now
    a <- filter(df, truth == col)$n
    n <- sum(df$n)
    
    (a / n) / (A / N)
    
  }
  
  metric_vec_template(
    metric_impl = ef1_impl,
    truth = truth,
    estimate = estimate,
    estimate_val = estimate_val,
    na_rm = na_rm,
    cls = "factor",
    estimator = estimator,
    ...
  )
  
}


# data frame implementation
ef1 <- function(data, ...) {
  UseMethod("ef1")
}

ef1 <- new_prob_metric(ef1, direction = "maximize")

ef1.data.frame <- function(data, truth, estimate, estimate_val, 
                           estimator = NULL, na_rm = TRUE, 
                           event_level = "first", ...) {
  metric_summarizer(
    metric_nm = "ef1",
    metric_fn = ef1_vec,
    data = data,
    truth = !! enquo(truth),
    estimate = !! enquo(estimate), 
    metric_fn_options = list(estimate_val = enquo(estimate_val)),
    estimator = estimator,
    na_rm = na_rm,
    event_level = event_level,
    ...
  )
}


# test metric -------------------------------------------------------------

data(two_class_example)

two_class_example %>% 
  ef1(truth, predicted, estimate_val = Class1)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 ef1     binary          1.94

# tune parameters ---------------------------------------------------------------
data(bivariate)

base_rec <-
  recipe(Class ~ ., data = bivariate_val)

rand_forest_ranger_spec <-
  rand_forest(mtry = tune(), min_n = tune(), trees = tune()) %>%
  set_engine('ranger') %>%
  set_mode('classification')

rf_grid <- grid_regular(finalize(mtry(), bivariate_train),
                        min_n(),
                        trees(),
                        levels = 3)

ef1_set <- metric_set(ef1)

rf_tune <- tune_grid(
  workflow(base_rec, rand_forest_ranger_spec),
  resamples = folds <- vfold_cv(bivariate_train,
                                v = 2,
                                strata = Class),
  grid = rf_grid,
  metrics = ef1_set(Class, .pred_class, estimate = .pred_class),
  control = control_grid(
    save_pred = TRUE,
    verbose = TRUE
  )
)
#> Error:
#> ! In metric: `ef1`
#> object 'Class' not found

Created on 2022-04-13 by the reprex package (v2.0.1)

I think this is just a slight misunderstanding of how the yardstick probability metrics are supposed to work.

First, note that you don't actually need the factor estimate column in your metric calculation. You really only need the class probability predictions for the event of interest (here you call that estimate_val).

That's good news, because yardstick doesn't actually support supplying both the hard class predictions through estimate and the soft class probability predictions at the same time.

We also need to refactor your metric function to take ... rather than estimate_val. Most class probability functions take ... because they support truth columns with >2 levels, which requires supplying multiple columns of class probabilities through .... You seem to want to only support binary truth columns, which is fine, but you still need to conform to the standard prob-metric signature to make this work.
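
For comparison, a built-in probability metric like roc_auc() illustrates that signature: with a binary truth you pass a single probability column, and with a multiclass truth every class probability column goes through .... A quick sketch using the hpc_cv data that ships with yardstick:

data(hpc_cv)

# binary-style call: one probability column for the event of interest
two_class_example %>% roc_auc(truth, Class1)

# multiclass call: all class probability columns are passed through `...`
hpc_cv %>% roc_auc(obs, VF, F, M, L)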

So it should look something like this:

(Also note that if you want to use this with parallel tuning, you'll have to put your metric in a real R package and supply that package to the pkgs argument of tune::control_grid() to ensure the metric function is available on the workers. See Custom metrics don't seem to work with parallel processing · Issue #479 · tidymodels/tune · GitHub.)

library(tidyverse)
library(tidymodels)
library(modeldata)

# metric ------------------------------------------------------------------

# get the correct event level
truth_event_level <- function(truth, event_level) {
  if (identical(event_level, "first")) {
    levels(truth)[[1]]
  } else {
    levels(truth)[[2]]
  }
}

# control what type of data can be used
finalize_estimator_internal.ef1 <- function(metric_dispatcher, x, estimator) {
  validate_estimator(estimator, estimator_override = "binary")
  
  if(!is.null(estimator)) {
    return(estimator)
  }
  
  lvls <- levels(x)
  
  if(length(lvls) > 2) {
    stop("A multiclass `truth` input was provided, but only `binary` is supported.")
  } 
  
  "binary"
}

# vector implementation
ef1_vec <- function(truth, 
                    estimate,
                    estimator = NULL, 
                    event_level = "first", 
                    na_rm = TRUE) {
  estimator <- finalize_estimator(truth, estimator, metric_class = "ef1")
  
  ef1_impl <- function(truth, estimate) {
    col <- truth_event_level(truth, event_level)
    
    N <- length(truth)
    
    A <- sum(truth == col)
    
    df <- bind_cols(truth = truth, estimate = estimate) %>%
      arrange(-estimate) %>%
      slice_head(prop = 0.1) %>%
      group_by(truth) %>%
      tally()
    
    # there's probably a better way of getting the event name, but this works for now
    a <- filter(df, truth == col)$n
    n <- sum(df$n)
    
    (a / n) / (A / N)
  }
  
  metric_vec_template(
    metric_impl = ef1_impl,
    truth = truth,
    estimate = estimate,
    estimator = estimator,
    na_rm = na_rm,
    cls = c("factor", "numeric")
  )
}


# data frame implementation
ef1 <- function(data, ...) {
  UseMethod("ef1")
}
ef1 <- new_prob_metric(ef1, direction = "maximize")

ef1.data.frame <- function(data, 
                           truth, 
                           ...,
                           estimator = NULL, 
                           na_rm = TRUE, 
                           event_level = "first") {
  estimate <- dots_to_estimate(data, !!! enquos(...))
  
  metric_summarizer(
    metric_nm = "ef1",
    metric_fn = ef1_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!estimate, 
    estimator = estimator,
    na_rm = na_rm,
    event_level = event_level
  )
}


# test metric -------------------------------------------------------------

data(two_class_example)

two_class_example %>% 
  ef1(truth, Class1)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 ef1     binary          1.94

# tune parameters ---------------------------------------------------------------
data(bivariate)

base_rec <-
  recipe(Class ~ ., data = bivariate_val)

rand_forest_ranger_spec <-
  rand_forest(mtry = tune(), min_n = tune(), trees = tune()) %>%
  set_engine('ranger') %>%
  set_mode('classification')

rf_grid <- grid_regular(finalize(mtry(), bivariate_train),
                        min_n(),
                        trees(),
                        levels = 3)

ef1_set <- metric_set(ef1)

rf_tune <- tune_grid(
  workflow(base_rec, rand_forest_ranger_spec),
  resamples = vfold_cv(bivariate_train, v = 2, strata = Class),
  grid = rf_grid,
  metrics = ef1_set,
  control = control_grid(
    save_pred = TRUE,
    verbose = FALSE
  )
)
#> <SOME NOTES GET PRINTED HERE>

rf_tune
#> # Tuning results
#> # 2-fold cross-validation using stratification 
#> # A tibble: 2 × 5
#>   splits            id    .metrics          .notes           .predictions
#>   <list>            <chr> <list>            <list>           <list>      
#> 1 <split [504/505]> Fold1 <tibble [27 × 7]> <tibble [9 × 3]> <tibble>    
#> 2 <split [505/504]> Fold2 <tibble [27 × 7]> <tibble [9 × 3]> <tibble>    
#> 
#> There were issues with some computations:
#> 
#>   - Warning(s) x18: 3 columns were requested but there were 2 predictors in the data....
#> 
#> Use `collect_notes(object)` for more information.

rf_tune$.metrics[[1]]
#> # A tibble: 27 × 7
#>     mtry trees min_n .metric .estimator .estimate .config              
#>    <int> <int> <int> <chr>   <chr>          <dbl> <chr>                
#>  1     1     1     2 ef1     binary          1.24 Preprocessor1_Model01
#>  2     2     1     2 ef1     binary          1.18 Preprocessor1_Model02
#>  3     3     1     2 ef1     binary          1.33 Preprocessor1_Model03
#>  4     1     1    21 ef1     binary          1.40 Preprocessor1_Model04
#>  5     2     1    21 ef1     binary          1.33 Preprocessor1_Model05
#>  6     3     1    21 ef1     binary          1.37 Preprocessor1_Model06
#>  7     1     1    40 ef1     binary          1.43 Preprocessor1_Model07
#>  8     2     1    40 ef1     binary          1.46 Preprocessor1_Model08
#>  9     3     1    40 ef1     binary          1.43 Preprocessor1_Model09
#> 10     1  1000     2 ef1     binary          1.52 Preprocessor1_Model10
#> # … with 17 more rows

Created on 2022-04-14 by the reprex package (v2.0.1)
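
For reference, once the metric lives in an installed package, pointing the tuning workers at it might look something like this (the package name here is hypothetical):

# minimal sketch: "ef1metric" stands in for whatever package holds ef1()
ctrl <- control_grid(
  save_pred = TRUE,
  pkgs = "ef1metric"  # loaded on each parallel worker so ef1() can be found
)

control_race() from finetune accepts the same pkgs argument, so the racing case works the same way.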

Thanks so much @davis, both for pointing out that I don't actually need estimate and for clarifying how the yardstick prob metrics work. I will work on packaging the metric for parallelization and hopefully won't run into any problems.

Thanks again!
