Hello, does anyone know how to get predicted class probabilities when using collect_predictions with a workflow SET in tidymodels? I'm trying to plot ROC curves and I can use collect_predictions, but I only get the predicted class back. You can't use type = "prob" in the collect_predictions function, so I'm struggling with how to get these class probabilities.
Thanks for the post, @ddelzell!
Do you happen to be setting your metrics explicitly? If so, have you made sure to include a metric that acts on class probabilities in your metric set? e.g.
class(roc_auc)
#> [1] "prob_metric" "metric" "function"
rather than
class(accuracy)
#> [1] "class_metric" "metric" "function"
Created on 2022-08-24 by the reprex package (v2.0.1)
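To see the effect of the metric set on the saved predictions, here's a minimal sketch rather than a definitive recipe. It assumes a single untuned glm workflow on `two_class_dat`; `glm_set` and the helper `pred_cols()` are hypothetical names made up for this illustration. It resamples the same workflow set twice, once with class metrics only and once with `roc_auc` added, and compares the columns that `collect_predictions()` returns.

```r
library(tidymodels)
data(two_class_dat, package = "modeldata")

set.seed(1)
folds <- vfold_cv(two_class_dat, v = 5)

# A one-workflow set: formula preprocessor + logistic regression (no tuning)
glm_set <-
  workflow_set(
    preproc = list(none = Class ~ A + B),
    models = list(glm = logistic_reg() %>% set_engine("glm"))
  )

# Hypothetical helper: resample the set with the given metric set and
# return the column names of the collected predictions.
pred_cols <- function(metrics) {
  glm_set %>%
    workflow_map(
      resamples = folds,
      metrics = metrics,
      control = control_grid(save_workflow = TRUE, save_pred = TRUE)
    ) %>%
    collect_predictions() %>%
    names()
}

# Class metrics only: hard class predictions are all that is needed
class_only <- pred_cols(metric_set(sens, spec, accuracy))

# Adding a probability metric means class probabilities are needed too
with_prob <- pred_cols(metric_set(sens, spec, accuracy, roc_auc))

# Columns that only show up once roc_auc is in the metric set
setdiff(with_prob, class_only)
```

The `metrics` argument is passed through `workflow_map()` to the underlying tuning or resampling function, so the same metric set applies to every workflow in the set.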
If that doesn't do the trick for you, could you please supply a reprex demonstrating your issue? Here's some example code that generates class probabilities with collect_predictions() on a workflow set to get you started.
library(tidymodels)
data(two_class_dat, package = "modeldata")
set.seed(1)
folds <- vfold_cv(two_class_dat, v = 5)
decision_tree_rpart_spec <-
  decision_tree(min_n = tune(), cost_complexity = tune()) %>%
  set_engine('rpart') %>%
  set_mode('classification')
logistic_reg_glm_spec <-
  logistic_reg() %>%
  set_engine('glm')
mars_earth_spec <-
  mars(prod_degree = tune()) %>%
  set_engine('earth') %>%
  set_mode('classification')
yj_recipe <-
  recipe(Class ~ ., data = two_class_dat) %>%
  step_YeoJohnson(A, B)
two_class_set <-
  workflow_set(
    preproc = list(none = Class ~ A + B, yj_trans = yj_recipe),
    models = list(cart = decision_tree_rpart_spec, glm = logistic_reg_glm_spec,
                  mars = mars_earth_spec)
  )
two_class_res <-
  two_class_set %>%
  workflow_map(
    resamples = folds,
    grid = 10,
    seed = 2,
    verbose = TRUE,
    control = control_grid(save_workflow = TRUE, save_pred = TRUE)
  )
#> i 1 of 6 tuning: none_cart
#> ✔ 1 of 6 tuning: none_cart (4.1s)
#> i No tuning parameters. `fit_resamples()` will be attempted
#> i 2 of 6 resampling: none_glm
#> ✔ 2 of 6 resampling: none_glm (426ms)
#> i 3 of 6 tuning: none_mars
#> ! Fold5: preprocessor 1/1, model 1/2: glm.fit: fitted probabilities numerically 0 or 1 occurred
#> ✔ 3 of 6 tuning: none_mars (1.1s)
#> i 4 of 6 tuning: yj_trans_cart
#> ✔ 4 of 6 tuning: yj_trans_cart (4.3s)
#> i No tuning parameters. `fit_resamples()` will be attempted
#> i 5 of 6 resampling: yj_trans_glm
#> ✔ 5 of 6 resampling: yj_trans_glm (555ms)
#> i 6 of 6 tuning: yj_trans_mars
#> ! Fold4: preprocessor 1/1, model 2/2: glm.fit: fitted probabilities numerically 0 or 1 occurred
#> ✔ 6 of 6 tuning: yj_trans_mars (1.1s)
collect_predictions(two_class_res)
#> # A tibble: 20,566 × 9
#>    wflow_id  .config           preproc model  .row Class .pred…¹ .pred…² .pred…³
#>    <chr>     <chr>             <chr>   <chr> <int> <fct>   <dbl>   <dbl> <fct>
#>  1 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.08    0.92  Class2
#>  2 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.2     0.8   Class2
#>  3 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.208   0.792 Class2
#>  4 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.103   0.897 Class2
#>  5 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.08    0.92  Class2
#>  6 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.103   0.897 Class2
#>  7 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0       1     Class2
#>  8 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.229   0.771 Class2
#>  9 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.229   0.771 Class2
#> 10 none_cart Preprocessor1_Mo… formula deci…     1 Clas…   0.216   0.784 Class2
#> # … with 20,556 more rows, and abbreviated variable names ¹.pred_Class1,
#> # ².pred_Class2, ³.pred_class
Created on 2022-08-24 by the reprex package (v2.0.1)
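Since the end goal was ROC curves, here's a rough sketch of going from the collected predictions to a plot. It assumes a smaller workflow set than the one above (one untuned glm with two preprocessors, so the names `glm_set`, `glm_res`, `roc_df`, and `roc_plot` are just for this example) and relies on the default metrics, which include a probability metric. The `.pred_Class1` column is used because `Class1` is the first factor level, which yardstick treats as the event by default.

```r
library(tidymodels)
data(two_class_dat, package = "modeldata")

set.seed(1)
folds <- vfold_cv(two_class_dat, v = 5)

yj_recipe <-
  recipe(Class ~ ., data = two_class_dat) %>%
  step_YeoJohnson(A, B)

# One untuned model crossed with two preprocessors
glm_set <-
  workflow_set(
    preproc = list(none = Class ~ A + B, yj_trans = yj_recipe),
    models = list(glm = logistic_reg() %>% set_engine("glm"))
  )

# save_pred = TRUE keeps the out-of-sample predictions for each fold
glm_res <-
  glm_set %>%
  workflow_map(
    resamples = folds,
    control = control_grid(save_workflow = TRUE, save_pred = TRUE)
  )

# One ROC curve per workflow, computed from the held-out predictions
roc_df <-
  collect_predictions(glm_res) %>%
  group_by(wflow_id) %>%
  roc_curve(Class, .pred_Class1)

# autoplot() on a grouped roc_curve() result draws one line per group
roc_plot <- autoplot(roc_df)
roc_plot
```

For tuned workflows, collect_predictions() returns predictions for every candidate configuration, so you would probably want to narrow things down first, for example with its `select_best = TRUE` and `metric` arguments, before computing the curves.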
Simon, thank you, that worked! I was explicitly creating a metric set with sens, spec, accuracy. As soon as I added 'roc_auc' I got the class probabilities. Thank you for the quick suggestion!