ROC Curves with Tidymodels Workflow Sets

Hello, does anyone know how to get predicted class probabilities from collect_predictions when using a workflow SET in tidymodels? I'm trying to plot ROC curves, but collect_predictions only gives me the predicted class back. collect_predictions doesn't accept type = "prob", so I'm not sure how to get the class probabilities.

Thanks for the post, @ddelzell!

Do you happen to be setting your metrics explicitly? If so, have you made sure to include a metric that acts on class probabilities in your metric set? e.g.

class(roc_auc)
#> [1] "prob_metric" "metric"      "function"

rather than

class(accuracy)
#> [1] "class_metric" "metric"       "function"

Created on 2022-08-24 by the reprex package (v2.0.1)
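To see the difference outside of tuning, here is a minimal sketch that calls yardstick directly on the two_class_example data from modeldata (my choice of data, not from the question). Once roc_auc is part of the metric set, the set needs a class probability column (Class1 here) in addition to the hard class prediction. As far as I understand it, tune inspects the metric set in the same way to decide which prediction columns to keep when save_pred = TRUE, so a set made only of class metrics keeps only .pred_class.

```r
library(yardstick)

# Example predictions shipped with modeldata: truth, Class1, Class2, predicted
data(two_class_example, package = "modeldata")

# sens, spec and accuracy only use the hard prediction (`estimate`);
# roc_auc additionally needs the probability of the event class (Class1).
cls_prob_metrics <- metric_set(sens, spec, accuracy, roc_auc)

res <- cls_prob_metrics(
  two_class_example,
  truth = truth,
  Class1,                # probability column, consumed by roc_auc
  estimate = predicted   # hard class column, consumed by the class metrics
)
res
```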

If that doesn't do the trick for you, could you please supply a reprex demonstrating your issue? To get you started, here's some example code that returns class probabilities from collect_predictions on a workflow set.

library(tidymodels)

data(two_class_dat, package = "modeldata")

set.seed(1)
folds <- vfold_cv(two_class_dat, v = 5)

decision_tree_rpart_spec <-
  decision_tree(min_n = tune(), cost_complexity = tune()) %>%
  set_engine('rpart') %>%
  set_mode('classification')

logistic_reg_glm_spec <-
  logistic_reg() %>%
  set_engine('glm')

mars_earth_spec <-
  mars(prod_degree = tune()) %>%
  set_engine('earth') %>%
  set_mode('classification')

yj_recipe <-
  recipe(Class ~ ., data = two_class_dat) %>%
  step_YeoJohnson(A, B)

two_class_set <-
  workflow_set(
    preproc = list(none = Class ~ A + B, yj_trans = yj_recipe),
    models = list(cart = decision_tree_rpart_spec, glm = logistic_reg_glm_spec,
                  mars = mars_earth_spec)
  )

two_class_res <-
  two_class_set %>%
  workflow_map(
    resamples = folds,
    grid = 10,
    seed = 2,
    verbose = TRUE,
    control = control_grid(save_workflow = TRUE, save_pred = TRUE)
  )
#> i 1 of 6 tuning:     none_cart
#> ✔ 1 of 6 tuning:     none_cart (4.1s)
#> i    No tuning parameters. `fit_resamples()` will be attempted
#> i 2 of 6 resampling: none_glm
#> ✔ 2 of 6 resampling: none_glm (426ms)
#> i 3 of 6 tuning:     none_mars
#> ! Fold5: preprocessor 1/1, model 1/2: glm.fit: fitted probabilities numerically 0 or 1 occurred
#> ✔ 3 of 6 tuning:     none_mars (1.1s)
#> i 4 of 6 tuning:     yj_trans_cart
#> ✔ 4 of 6 tuning:     yj_trans_cart (4.3s)
#> i    No tuning parameters. `fit_resamples()` will be attempted
#> i 5 of 6 resampling: yj_trans_glm
#> ✔ 5 of 6 resampling: yj_trans_glm (555ms)
#> i 6 of 6 tuning:     yj_trans_mars
#> ! Fold4: preprocessor 1/1, model 2/2: glm.fit: fitted probabilities numerically 0 or 1 occurred
#> ✔ 6 of 6 tuning:     yj_trans_mars (1.1s)
collect_predictions(two_class_res)
#> # A tibble: 20,566 × 9
#>    wflow_id  .config           preproc model  .row Class .pred…¹ .pred…² .pred…³
#>    <chr>     <chr>             <chr>   <chr> <int> <fct>   <dbl>   <dbl> <fct>  
#>  1 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.08    0.92  Class2 
#>  2 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.2     0.8   Class2 
#>  3 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.208   0.792 Class2 
#>  4 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.103   0.897 Class2 
#>  5 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.08    0.92  Class2 
#>  6 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.103   0.897 Class2 
#>  7 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0       1     Class2 
#>  8 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.229   0.771 Class2 
#>  9 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.229   0.771 Class2 
#> 10 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.216   0.784 Class2 
#> # … with 20,556 more rows, and abbreviated variable names ¹.pred_Class1,
#> #   ².pred_Class2, ³.pred_class

Created on 2022-08-24 by the reprex package (v2.0.1)

Simon, thank you, that worked! I was explicitly creating a metric set with sens, spec, and accuracy. As soon as I added roc_auc, I got the class probabilities. Thanks for the quick suggestion!
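For anyone landing here later, a minimal self-contained sketch of the fix plus the ROC plot the question was after. It assumes a recent tidymodels install (with rpart), trims the workflow set to two models, and passes the metric set through workflow_map's `...` to tune_grid / fit_resamples. collect_predictions(select_best = TRUE) keeps only the best configuration per workflow, so each workflow contributes one curve; the object names are just illustrative.

```r
library(tidymodels)

data(two_class_dat, package = "modeldata")

set.seed(1)
folds <- vfold_cv(two_class_dat, v = 5)

# The class metrics from the question plus roc_auc; the probability metric
# is what makes tune save the .pred_Class1 / .pred_Class2 columns.
cls_metrics <- metric_set(sens, spec, accuracy, roc_auc)

roc_set <-
  workflow_set(
    preproc = list(none = Class ~ A + B),
    models = list(
      glm = logistic_reg() %>% set_engine("glm"),
      cart = decision_tree(min_n = tune()) %>%
        set_engine("rpart") %>%
        set_mode("classification")
    )
  )

roc_res <-
  roc_set %>%
  workflow_map(
    resamples = folds,
    grid = 5,
    seed = 2,
    metrics = cls_metrics,
    control = control_grid(save_pred = TRUE)
  )

# Out-of-fold predictions, keeping the best roc_auc configuration per workflow
roc_preds <- collect_predictions(roc_res, select_best = TRUE, metric = "roc_auc")

# Class1 is the first factor level, which yardstick treats as the event by default
roc_data <-
  roc_preds %>%
  group_by(wflow_id) %>%
  roc_curve(Class, .pred_Class1)

# One curve per workflow, coloured by wflow_id
autoplot(roc_data)
```

Grouping by wflow_id before roc_curve() is what gives separate curves; without select_best = TRUE the tuned CART workflow would mix predictions from every candidate min_n value into a single curve.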
