Help with using save_pred = TRUE with control_race() and workflow_map()

I'm trying to use racing (tune_race_anova() from finetune) with workflowsets, but I'm having an issue with save_pred = TRUE.

The reprex below shows that I set save_pred = TRUE within control_race(), but when I run collect_predictions() I get an error: The `.predictions` column does not exist. Refit with the control argument `save_pred = TRUE` to save predictions.

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.1.1
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
#> Warning: package 'dials' was built under R version 4.1.1
#> Warning: package 'ggplot2' was built under R version 4.1.1
#> Warning: package 'infer' was built under R version 4.1.1
#> Warning: package 'parsnip' was built under R version 4.1.1
#> Warning: package 'tune' was built under R version 4.1.1
#> Warning: package 'workflows' was built under R version 4.1.1
#> Warning: package 'workflowsets' was built under R version 4.1.1
#> Warning: package 'yardstick' was built under R version 4.1.1
library(discrim)
#> Warning: package 'discrim' was built under R version 4.1.1
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness
library(workflowsets)
library(finetune)
#> Warning: package 'finetune' was built under R version 4.1.1

data(parabolic)

set.seed(1)
split <- initial_split(parabolic)
train_set <- training(split)
test_set <- testing(split)
set.seed(2)
train_resamples <- bootstraps(train_set, times = 5)

mars_disc_spec <- 
  discrim_flexible(prod_degree = tune()) %>% 
  set_engine("earth")

reg_disc_sepc <- 
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>% 
  set_engine("klaR")

cart_spec <- 
  decision_tree(cost_complexity = tune(), min_n = tune()) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")


all_workflows <- 
  workflow_set(
    preproc = list("formula" = class ~ .),
    models = list(regularized = reg_disc_sepc, mars = mars_disc_spec, cart = cart_spec)
  )

class_metrics <- 
  metric_set(roc_auc, accuracy, sensitivity, specificity)

race_ctrl <-
  control_race(
    verbose = TRUE,
    allow_par = TRUE,
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE
  )

doParallel::registerDoParallel()

wf_res <- 
  all_workflows %>% 
  workflow_map(fn = "tune_race_anova",
               resamples = train_resamples,
               grid = 10,
               metrics = class_metrics, 
               ctrl = race_ctrl
  )
#> Warning: The `...` are not used in this function but one or more objects were
#> passed: 'ctrl'
#> Warning: The `...` are not used in this function but one or more objects were
#> passed: 'ctrl'

#> Warning: The `...` are not used in this function but one or more objects were
#> passed: 'ctrl'

workflowsets::collect_predictions(wf_res)
#> Error: Problem with `mutate()` column `predictions`.
#> i `predictions = purrr::map(...)`.
#> x The `.predictions` column does not exist. Refit with the control argument `save_pred = TRUE` to save predictions.
Created on 2021-10-21 by the reprex package (v2.0.1)

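Try using control = race_ctrl instead of ctrl = race_ctrl. I suspect that is the issue, based on the warnings: ctrl is not a recognized argument name, so it is passed through `...` and ignored, and the tuning function runs with its default control_race(), where save_pred = FALSE.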
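
For reference, here is a minimal sketch of the corrected call. It assumes the same packages as the reprex above but is trimmed to a single CART model so it runs quickly; the names small_set, small_res, and preds are placeholders for illustration and are not from the original post.

```r
library(tidymodels)
library(workflowsets)
library(finetune)

data(parabolic)

# Racing needs several resamples to compare candidates across.
set.seed(2)
resamples <- bootstraps(parabolic, times = 5)

cart_spec <-
  decision_tree(cost_complexity = tune(), min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# Hypothetical one-model workflow set, just to keep the example small.
small_set <-
  workflow_set(
    preproc = list(formula = class ~ .),
    models = list(cart = cart_spec)
  )

race_ctrl <- control_race(save_pred = TRUE, save_workflow = TRUE)

small_res <-
  small_set %>%
  workflow_map(
    fn = "tune_race_anova",
    resamples = resamples,
    grid = 10,
    metrics = metric_set(roc_auc, accuracy),
    control = race_ctrl  # the argument must be named `control`, not `ctrl`
  )

# With save_pred = TRUE actually reaching the tuning function,
# the out-of-sample predictions can be collected.
preds <- collect_predictions(small_res)
```

If the "`...` are not used in this function" warning no longer appears, the control object was picked up by tune_race_anova().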

That was it, thanks!
