Hey there!
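`tune_grid()` may not return raw class probabilities if every metric you choose is a class metric (such as `accuracy`). Adding a probability metric (such as `roc_auc`) to the metric set gets them back. See below: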
```r
library(tidymodels)

x <- tibble(
  class = sample(c("class_1", "class_2"), 100, replace = TRUE),
  a = rnorm(100),
  b = rnorm(100)
)

spec_lasso <-
  logistic_reg(engine = 'glmnet', penalty = tune(), mixture = 1)

rec <-
  recipe(class ~ ., x) %>%
  step_normalize(all_numeric_predictors())

# tune_grid with only a class metric, accuracy ---------------------------------
res <- tune_grid(
  spec_lasso,
  preprocessor = rec,
  resamples = vfold_cv(x, 2),
  grid = 2,
  metrics = metric_set(accuracy),
  control = control_grid(save_pred = TRUE)
)

# only getting hard class predictions :(
res$.predictions[[1]]
#> # A tibble: 100 × 5
#>    .pred_class  .row      penalty class   .config             
#>    <fct>       <int>        <dbl> <fct>   <chr>               
#>  1 class_1         1 0.0000000120 class_2 Preprocessor1_Model1
#>  2 class_1         4 0.0000000120 class_1 Preprocessor1_Model1
#>  3 class_1         5 0.0000000120 class_2 Preprocessor1_Model1
#>  4 class_1         7 0.0000000120 class_1 Preprocessor1_Model1
#>  5 class_1         8 0.0000000120 class_2 Preprocessor1_Model1
#>  6 class_1        10 0.0000000120 class_1 Preprocessor1_Model1
#>  7 class_1        11 0.0000000120 class_2 Preprocessor1_Model1
#>  8 class_1        15 0.0000000120 class_1 Preprocessor1_Model1
#>  9 class_1        18 0.0000000120 class_1 Preprocessor1_Model1
#> 10 class_1        22 0.0000000120 class_1 Preprocessor1_Model1
#> # … with 90 more rows

# tune_grid with both a class metric and a prob metric --------------------------
res2 <- tune_grid(
  spec_lasso,
  preprocessor = rec,
  resamples = vfold_cv(x, 2),
  grid = 2,
  # see here: add a prob metric, roc_auc
  metrics = metric_set(accuracy, roc_auc),
  control = control_grid(save_pred = TRUE)
)

# now we also get .pred_class_1 and .pred_class_2
res2$.predictions[[1]]
#> # A tibble: 100 × 7
#>    .pred_class  .row     penalty .pred_class_1 .pred_class_2 class   .config    
#>    <fct>       <int>       <dbl>         <dbl>         <dbl> <fct>   <chr>      
#>  1 class_1         1 0.000000173         0.624         0.376 class_2 Preprocess…
#>  2 class_2         4 0.000000173         0.455         0.545 class_1 Preprocess…
#>  3 class_1         6 0.000000173         0.720         0.280 class_1 Preprocess…
#>  4 class_1        10 0.000000173         0.584         0.416 class_1 Preprocess…
#>  5 class_1        12 0.000000173         0.691         0.309 class_1 Preprocess…
#>  6 class_2        19 0.000000173         0.356         0.644 class_1 Preprocess…
#>  7 class_2        21 0.000000173         0.417         0.583 class_2 Preprocess…
#>  8 class_2        22 0.000000173         0.338         0.662 class_1 Preprocess…
#>  9 class_2        23 0.000000173         0.341         0.659 class_1 Preprocess…
#> 10 class_2        24 0.000000173         0.481         0.519 class_2 Preprocess…
#> # … with 90 more rows
```
Created on 2022-05-03 by the reprex package (v2.0.1)
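If you want to check ahead of time whether a metric set will make `tune_grid()` keep probabilities, one option is to look at the class of each metric. This is a minimal sketch, assuming a recent yardstick where each metric function carries a class such as `"class_metric"` or `"prob_metric"` and a metric set stores its component metrics in a `"metrics"` attribute. My understanding is that tune relies on this same information to decide which prediction types to generate, but treat that as an assumption rather than documented behavior.

```r
library(yardstick)

# Class metrics such as accuracy only need hard predictions (.pred_class);
# probability metrics such as roc_auc need the .pred_<level> columns.
class(accuracy)[1]
class(roc_auc)[1]

# Same metric set as in res2 above; list the type of each metric in it.
# Assumption: metric_set() keeps its component metrics in attr(, "metrics").
ms <- metric_set(accuracy, roc_auc)
metric_types <- vapply(attr(ms, "metrics"), function(f) class(f)[1], character(1))
metric_types

# If no entry is "prob_metric", expect only .pred_class in the saved predictions.
any(metric_types == "prob_metric")
```

Running the same check on `metric_set(accuracy)` alone should show no `"prob_metric"` entry, which lines up with the first `tune_grid()` call above only saving `.pred_class`.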
I can't be sure where your issue is coming from without access to your `rf_wf` and `auction_folds` objects, though. If this doesn't address your problem, could you please provide a minimal reprex (reproducible example)? A reprex will help me troubleshoot and fix your issue more quickly.