I think this is just a slight misunderstanding of how the yardstick probability metrics are supposed to work.
First, note that you don't actually need the factor estimate column in your metric calculation. You really only need the class probability predictions for the event of interest (here you call that estimate_val).
That's good news, because yardstick doesn't actually support supplying both the hard class predictions through estimate and the soft class probability predictions at the same time.
We also need to refactor your metric function to take `...` rather than `estimate_val`. Most class probability functions take `...` because they support truth columns with more than two levels, which requires supplying multiple columns of class probabilities through `...`. You seem to want to only support binary truth columns, which is fine, but you still need to conform to the standard prob-metric signature to make this work.
So it should look something like this:
(Also note that if you want to use this with parallel tuning, then you'll have to put your metric in a true R package and supply it to the `pkgs` argument of `tune::control_grid(pkgs = )` to ensure the metric function is available on the workers. See the GitHub issue tidymodels/tune#479, "Custom metrics don't seem to work with parallel processing".)
library(tidyverse)
library(tidymodels)
library(modeldata)
# metric ------------------------------------------------------------------
# Return the level of the factor `truth` that counts as the "event".
# `event_level = "first"` selects the first factor level; any other value
# selects the second level.
truth_event_level <- function(truth, event_level) {
  position <- if (identical(event_level, "first")) 1L else 2L
  levels(truth)[[position]]
}
# control what type of data can be used
# Resolve the estimator for `ef1`: only "binary" is permitted. A supplied
# estimator is validated and returned unchanged; when none is supplied, a
# `truth` factor with more than two levels is rejected and "binary" is used.
finalize_estimator_internal.ef1 <- function(metric_dispatcher, x, estimator) {
  validate_estimator(estimator, estimator_override = "binary")

  if (is.null(estimator)) {
    if (length(levels(x)) > 2) {
      stop("A multiclass `truth` input was provided, but only `binary` is supported.")
    }
    return("binary")
  }

  estimator
}
# vector implementation
# Enrichment factor: the event rate among the top `prop` fraction of
# observations (ranked by predicted event probability) divided by the event
# rate in the whole sample, i.e. (a / n) / (A / N).
#
# truth       factor whose event level is chosen via `event_level`.
# estimate    numeric predicted probabilities for the event level.
# estimator   NULL or "binary"; resolved by finalize_estimator().
# event_level "first" or "second": which level of `truth` is the event.
# na_rm       passed to metric_vec_template() to drop missing values.
# prop        fraction of top-ranked observations to inspect (default 0.1,
#             the original hard-coded value).
#
# Returns a single number: 0 when the top slice holds no events (and the
# sample contains events), NA when prop * N rounds down to zero rows.
ef1_vec <- function(truth,
                    estimate,
                    estimator = NULL,
                    event_level = "first",
                    na_rm = TRUE,
                    prop = 0.1) {
  estimator <- finalize_estimator(truth, estimator, metric_class = "ef1")

  ef1_impl <- function(truth, estimate) {
    event <- truth_event_level(truth, event_level)
    N <- length(truth)
    A <- sum(truth == event)

    # Highest event probabilities first; arrange() is stable, so ties keep
    # their original order.
    top <- tibble(truth = truth, estimate = estimate) %>%
      arrange(desc(estimate)) %>%
      slice_head(prop = prop)

    n <- nrow(top)
    if (n == 0) {
      # slice_head() rounds prop * N down, so small inputs select no rows.
      return(NA_real_)
    }

    # sum() gives 0 (not integer(0)) when no events are in the top slice.
    a <- sum(top$truth == event)

    (a / n) / (A / N)
  }

  metric_vec_template(
    metric_impl = ef1_impl,
    truth = truth,
    estimate = estimate,
    estimator = estimator,
    na_rm = na_rm,
    cls = c("factor", "numeric")
  )
}
# data frame implementation
# S3 generic dispatching on `data`; new_prob_metric() then registers it with
# yardstick as a class-probability metric whose larger values are better.
ef1 <- function(data, ...) UseMethod("ef1")
ef1 <- new_prob_metric(ef1, direction = "maximize")
# Data frame method. `truth` names the factor column and `...` selects the
# class-probability column(s); dots_to_estimate() combines them into the
# `estimate` argument that metric_summarizer() forwards to ef1_vec().
ef1.data.frame <- function(data,
                           truth,
                           ...,
                           estimator = NULL,
                           na_rm = TRUE,
                           event_level = "first") {
  prob_cols <- enquos(...)
  estimate_expr <- dots_to_estimate(data, !!!prob_cols)

  metric_summarizer(
    metric_nm = "ef1",
    metric_fn = ef1_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!estimate_expr,
    estimator = estimator,
    na_rm = na_rm,
    event_level = event_level
  )
}
# test metric -------------------------------------------------------------
# Score the Class1 probabilities in two_class_example against `truth`.
data(two_class_example)
ef1(two_class_example, truth, Class1)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 ef1 binary 1.94
# tune parameters ---------------------------------------------------------------
# Tune a ranger random forest on the bivariate data, scoring every candidate
# with the custom `ef1` metric.
data(bivariate)

base_rec <-
  recipe(Class ~ ., data = bivariate_val)

rand_forest_ranger_spec <-
  rand_forest(mtry = tune(), min_n = tune(), trees = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")

# finalize() sets the upper bound of mtry() from the columns it is given, so
# pass only the predictors. Passing all of bivariate_train (which also holds
# the `Class` outcome) yields mtry values up to 3 with only 2 predictors,
# which is what triggers the "3 columns were requested" warnings.
rf_grid <- grid_regular(
  finalize(mtry(), select(bivariate_train, -Class)),
  min_n(),
  trees(),
  levels = 3
)

ef1_set <- metric_set(ef1)

rf_tune <- tune_grid(
  workflow(base_rec, rand_forest_ranger_spec),
  resamples = vfold_cv(bivariate_train, v = 2, strata = Class),
  grid = rf_grid,
  metrics = ef1_set,
  control = control_grid(
    save_pred = TRUE,
    verbose = FALSE
  )
)
# NOTE: the printed output below was captured before the mtry() fix above.
# With mtry finalized on the predictors only, the "3 columns were requested"
# warnings should no longer appear and the grid contains fewer mtry values,
# so the row counts shown here will differ.
#> <SOME NOTES GET PRINTED HERE>
rf_tune
#> # Tuning results
#> # 2-fold cross-validation using stratification
#> # A tibble: 2 × 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [504/505]> Fold1 <tibble [27 × 7]> <tibble [9 × 3]> <tibble>
#> 2 <split [505/504]> Fold2 <tibble [27 × 7]> <tibble [9 × 3]> <tibble>
#>
#> There were issues with some computations:
#>
#> - Warning(s) x18: 3 columns were requested but there were 2 predictors in the data....
#>
#> Use `collect_notes(object)` for more information.
rf_tune$.metrics[[1]]
#> # A tibble: 27 × 7
#> mtry trees min_n .metric .estimator .estimate .config
#> <int> <int> <int> <chr> <chr> <dbl> <chr>
#> 1 1 1 2 ef1 binary 1.24 Preprocessor1_Model01
#> 2 2 1 2 ef1 binary 1.18 Preprocessor1_Model02
#> 3 3 1 2 ef1 binary 1.33 Preprocessor1_Model03
#> 4 1 1 21 ef1 binary 1.40 Preprocessor1_Model04
#> 5 2 1 21 ef1 binary 1.33 Preprocessor1_Model05
#> 6 3 1 21 ef1 binary 1.37 Preprocessor1_Model06
#> 7 1 1 40 ef1 binary 1.43 Preprocessor1_Model07
#> 8 2 1 40 ef1 binary 1.46 Preprocessor1_Model08
#> 9 3 1 40 ef1 binary 1.43 Preprocessor1_Model09
#> 10 1 1000 2 ef1 binary 1.52 Preprocessor1_Model10
#> # … with 17 more rows
Created on 2022-04-14 by the reprex package (v2.0.1)