Predict a class using a threshold different from the 0.5 default with tidymodels

  1. In a standard two-class classification flow, is there any way in {parsnip}/{yardstick} ({tidymodels} for short) to predict a class using a threshold on the score ("probability") other than the 0.5 default?

  2. Is there a way to incorporate this decision into the training flow, i.e. learn the best threshold from the test ROC curve (for example, the threshold that maximizes sensitivity + specificity, since the AUC itself does not depend on the threshold)?

Note: I know how to do this manually; I'm asking about an argument that handles this seamlessly within the flow.

Example with simulated data:

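```r
library(tidyverse)
library(tidymodels)

# Simulate a two-class outcome from a logistic model
set.seed(1)
n <- 10000
x1 <- runif(n)
x2 <- runif(n)
t <- 1 + 2 * x1 + 3 * x2
y <- rbinom(n, 1, 1 / (1 + exp(-t)))

my_data <- tibble(y = factor(y), x1 = x1, x2 = x2)

# Center and scale the predictors, estimating the steps on my_data
my_rec <-
  recipe(y ~ ., data = my_data) %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors()) %>%
  prep(training = my_data)

# Apply the trained recipe once and reuse the result
baked_data <- bake(my_rec, new_data = my_data)

glm_mod <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(y ~ ., data = baked_data)

# .pred_0 / .pred_1 are class probabilities; .pred_class uses the 0.5 cutoff
glm_res <- predict(glm_mod, baked_data, type = "prob") %>%
  bind_cols(predict(glm_mod, baked_data, type = "class")) %>%
  bind_cols(y = my_data$y)
```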

I'm guessing the simplest answer to my first question would be an argument to `predict()` for `type = "class"`, something like `threshold` or `cutoff`, that would accept, say, 0.8 instead of 0.5. But this would also be needed in {yardstick} for just about any two-class metric (specificity, sensitivity, precision). In any case, I haven't found one in the docs or in any demo. Is the current solution to go manual with the predicted scores?
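For reference, the manual route can be sketched roughly as below. This is a minimal, self-contained sketch under a few assumptions: it re-simulates the data with a fixed seed, fits the glm directly without the recipe (centering and scaling do not change a logistic regression's fitted probabilities), and uses an arbitrary cutoff of 0.8 on `.pred_1`. The class is rebuilt with `if_else()` and then passed to {yardstick}; `event_level = "second"` is there because yardstick treats the first factor level (`"0"` here) as the event by default.

```r
library(dplyr)
library(parsnip)
library(yardstick)

# Simulated data, as in the question (seed added for reproducibility)
set.seed(2020)
n <- 10000
x1 <- runif(n)
x2 <- runif(n)
lin <- 1 + 2 * x1 + 3 * x2
y <- rbinom(n, 1, 1 / (1 + exp(-lin)))
my_data <- tibble(y = factor(y), x1 = x1, x2 = x2)

glm_mod <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(y ~ ., data = my_data)

# Class probabilities (.pred_0, .pred_1) plus the observed outcome
glm_res <- predict(glm_mod, my_data, type = "prob") %>%
  bind_cols(my_data %>% select(y))

# Manual thresholding: predict "1" only when P(y = 1) >= cutoff
cutoff <- 0.8
glm_res <- glm_res %>%
  mutate(
    .pred_class = factor(
      if_else(.pred_1 >= cutoff, "1", "0"),
      levels = levels(y)
    )
  )

# Two-class metrics on the thresholded classes;
# event_level = "second" makes level "1" the event of interest
class_metrics <- metric_set(sens, spec, precision)
metric_res <- class_metrics(glm_res, truth = y, estimate = .pred_class,
                            event_level = "second")
metric_res
```

Keeping the factor levels identical to those of `y` matters here, because yardstick expects `truth` and `estimate` to share the same levels.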

As for the second question, I have no idea...
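One manual way to choose the threshold from an ROC curve is sketched below. The selection criterion is an assumption swapped in for illustration: Youden's J (sensitivity + specificity - 1), since the AUC is the same for every threshold and cannot pick one by itself. The sketch holds out 25% of the simulated data with {rsample}, computes `roc_curve()` on that hold-out set for the event `"1"`, and keeps the threshold with the largest J; the split proportion and seed are arbitrary.

```r
library(dplyr)
library(parsnip)
library(rsample)
library(yardstick)

# Simulated data, as in the question (seed added for reproducibility)
set.seed(2020)
n <- 10000
x1 <- runif(n)
x2 <- runif(n)
lin <- 1 + 2 * x1 + 3 * x2
y <- rbinom(n, 1, 1 / (1 + exp(-lin)))
my_data <- tibble(y = factor(y), x1 = x1, x2 = x2)

# Hold out part of the data for choosing the threshold
data_split <- initial_split(my_data, prop = 0.75, strata = y)
train_data <- training(data_split)
holdout_data <- testing(data_split)

glm_mod <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(y ~ ., data = train_data)

holdout_res <- predict(glm_mod, holdout_data, type = "prob") %>%
  bind_cols(holdout_data %>% select(y))

# ROC curve for the event "1": one row per candidate threshold
roc_pts <- roc_curve(holdout_res, truth = y, .pred_1, event_level = "second")

# Youden's J = sensitivity + specificity - 1; keep the threshold maximizing it
# (the infinite end points added by roc_curve() are dropped first)
best <- roc_pts %>%
  filter(is.finite(.threshold)) %>%
  mutate(j = sensitivity + specificity - 1) %>%
  slice_max(j, n = 1, with_ties = FALSE)

best_threshold <- best$.threshold
best_threshold
```

`best_threshold` can then replace a fixed cutoff such as 0.8 when converting `.pred_1` into classes. Because it is chosen on the hold-out data, performance reported on that same data will be optimistic; a separate validation set or resampling would give a fairer estimate.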

This is planned for the next major version of workflows: you'll be able to specify a fixed threshold that is applied automatically, or say that you want to tune it.

The bad news is that we are currently spending 1-2 months on documentation.


Documentation is important, thanks Max!
