fitting each model from workflow_map results

Hi guys,

I am using tidymodels to train models with the following code. How can I fit each model from a workflow_set object? Right now, I have to extract each model one by one.

> run_chi_models[[i]] <- workflow_set(preproc = list(recipe1 = run_recipe1, recipe2 = run_recipe2), 
                                                                         models  = list(lda = lda_reg_model),
                                                                         cross = TRUE)

>run_tune_bayes_models_res[[i]]  <- run_chi_models[[i]] %>% 
    workflow_map("tune_bayes", resamples = run_k_cv, initial = 50 , iter = 100,
                 metrics  = metric_set(accuracy, roc_auc), verbose = T,  
                 control = control_bayes(save_workflow = FALSE, save_pred = FALSE)
    ) 
> recipe1_lda[[i]] <- finalize_workflow(extract_workflow(run_tune_bayes_models_res[[i]],  "recipe1_lda"),
                                      select_best(extract_workflow_set_result(run_tune_bayes_models_res[[i]], "recipe1_lda"), metric = "accuracy")) %>%
    last_fit(initial_splits,  metrics  = metric_set(accuracy))

> recipe2_lda[[i]] <- finalize_workflow(extract_workflow(run_tune_bayes_models_res[[i]],  "recipe2_lda"),
                                      select_best(extract_workflow_set_result(run_tune_bayes_models_res[[i]], "recipe2_lda"), metric = "accuracy")) %>%
    last_fit(initial_splits,  metrics  = metric_set(accuracy))

The devel version of tune has a fit_best() function that can be used on the tuning results.

Hi Max,

Thanks for your reply. I am trying to train the models with different random seeds using the code below. However, I ran into an error. It seems that fit_best() cannot handle list data.

library(tidymodels)
#> Warning: package 'ggplot2' was built under R version 4.2.3
#> Warning: package 'recipes' was built under R version 4.2.3
#> Warning: package 'workflows' was built under R version 4.2.3
library(themis)
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness

# Reprex: repeat a split / tune / fit cycle on `iris` for several random
# seeds, storing the results of each run in lists indexed by run number.
data(iris)


# Per-run containers; element i is filled in during iteration i of the loop.
run_chi_models <- list()
run_tune_bayes_models_res <- list()
lda_fit <- list()

# Number of repeated runs.
n_run <- 2

for (i in 1:n_run) {
  # Set the random seed: seed with the loop index, draw one random integer,
  # then reseed with that integer before splitting the data.
  set.seed(i)
  random_seed <- sample.int(10000, 1)
  set.seed(random_seed)
  # 80/20 train/test split of iris.
  iris_split <- initial_split(iris, prop = 0.80)
  # NOTE: a bare expression inside a `for` loop is not auto-printed, so this
  # line produces no output.
  iris_split
  iris_train <- training(iris_split)
  # iris_test is created here but not used again in this loop.
  iris_test <- testing(iris_split)
  # 2-fold cross-validation on the training set, stratified by Species.
  run_k_cv <- vfold_cv(iris_train, v = 2, repeats = 1, strata = "Species", pool = 0.2)
  
  # define two recipes
  # recipe1: formula only, no preprocessing steps.
  run_recipe1 <-
    recipe(Species ~ .,
           data = iris_train)
  # recipe2: same formula plus step_smote(), called here without any
  # column selector -- TODO confirm this is intended.
  run_recipe2 <-
    recipe(Species ~ .,
           data = iris_train) %>%
    step_smote()
  
  # define two models
  # LDA via the "mda" engine; `penalty` is marked for tuning.
  lda_reg_model <- discrim_linear(penalty = tune(), regularization_method = "shrink_cov") %>% 
    set_engine("mda", keep.fitted = TRUE) %>% 
    set_mode("classification") 
  
  # RBF SVM via the "kernlab" engine; `cost` and `rbf_sigma` are marked for tuning.
  svm_rbf_model <- svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
    set_mode("classification")  %>%
    set_engine("kernlab")
  
  # Generate a workflow set: with cross = TRUE the two recipes are combined
  # with the two models, giving four workflows
  # (recipe1_lda, recipe1_svm, recipe2_lda, recipe2_svm).
  run_chi_models[[i]] <- workflow_set(preproc = list(recipe1 = run_recipe1, recipe2 = run_recipe2), 
                                      models  = list(lda = lda_reg_model, svm = svm_rbf_model),
                                      cross = TRUE)
  
  # tuning 
  # Apply tune_bayes() to every workflow in the set (initial = 5, iter = 10),
  # scoring with accuracy and ROC AUC. save_workflow = TRUE keeps the workflow
  # with each tuning result.
  run_tune_bayes_models_res[[i]]  <- run_chi_models[[i]] %>% 
    workflow_map("tune_bayes", resamples = run_k_cv, initial = 5 , iter = 10,
                 metrics  = metric_set(accuracy, roc_auc), verbose = T,  
                 control = control_bayes(save_workflow = TRUE, save_pred = FALSE))
  
  # fit model
  
  # Question: how can each fitted workflow be saved under its workflow ID, e.g.
  # recipe1_lda[[i]] <- fit_best(run_tune_bayes_models_res[[i]], verbose = TRUE)
  # NOTE: the call below is the one that fails in this reprex -- it passes the
  # whole workflow_set to fit_best(), and the recorded error reports that there
  # is no `fit_best()` method for an object of class `workflow_set`.
  lda_fit[[i]] <- fit_best(run_tune_bayes_models_res[[i]], verbose = TRUE)                
  
}
#> i 1 of 4 tuning:     recipe1_lda
#> ! All of the accuracy values were identical. The Gaussian process model
#>   cannot be fit to the data. Try expanding the range of the tuning
#>   parameters.
#> → A | error:   Infinite values of the Deviance Function, 
#>                            unable to find optimum parameters 
#> 
#> There were issues with some computations   A: x1                                                 ✖ Optimization stopped prematurely; returning current results.
#> There were issues with some computations   A: x1There were issues with some computations   A: x1
#> ✔ 1 of 4 tuning:     recipe1_lda (2.4s)
#> i 2 of 4 tuning:     recipe1_svm
#> ✔ 2 of 4 tuning:     recipe1_svm (42.9s)
#> i 3 of 4 tuning:     recipe2_lda
#> ! All of the accuracy values were identical. The Gaussian process model
#>   cannot be fit to the data. Try expanding the range of the tuning
#>   parameters.
#> → A | error:   Infinite values of the Deviance Function, 
#>                            unable to find optimum parameters 
#>                
#> There were issues with some computations   A: x1                                                 ✖ Optimization stopped prematurely; returning current results.
#> There were issues with some computations   A: x1There were issues with some computations   A: x1
#> ✔ 3 of 4 tuning:     recipe2_lda (2.1s)
#> i 4 of 4 tuning:     recipe2_svm
#> ✔ 4 of 4 tuning:     recipe2_svm (41.2s)
#> Error in `fit_best()`:
#> ! There is no `fit_best()` method for an object with classes
#>   `workflow_set`, `tbl_df`, `tbl`, and `data.frame`.

#> Backtrace:
#>     ▆
#>  1. ├─tune::fit_best(run_tune_bayes_models_res[[i]], verbose = TRUE)
#>  2. └─tune:::fit_best.default(run_tune_bayes_models_res[[i]], verbose = TRUE)
#>  3.   └─cli::cli_abort("There is no `fit_best()` method for an object with \\\n     {cli::qty(cls)} class{?es} {.var {cls}}.")
#>  4.     └─rlang::abort(...)

Created on 2023-03-23 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.2 (2022-10-31 ucrt)
#>  os       Windows 10 x64 (build 22000)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  Chinese (Simplified)_China.utf8
#>  ctype    Chinese (Simplified)_China.utf8
#>  tz       Asia/Taipei
#>  date     2023-03-23
#>  pandoc   2.19.2 @ D:/Program Files/RStudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  broom        * 1.0.2      2022-12-15 [1] CRAN (R 4.2.2)
#>  class        * 7.3-20     2022-01-16 [1] CRAN (R 4.2.2)
#>  cli            3.5.0      2022-12-20 [1] CRAN (R 4.2.2)
#>  codetools      0.2-18     2020-11-04 [1] CRAN (R 4.2.2)
#>  colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.2.2)
#>  dials        * 1.1.0      2022-11-04 [1] CRAN (R 4.2.2)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.2)
#>  digest         0.6.31     2022-12-11 [1] CRAN (R 4.2.2)
#>  discrim      * 1.0.0      2022-06-23 [1] CRAN (R 4.2.2)
#>  dplyr        * 1.1.0      2023-01-29 [1] CRAN (R 4.2.2)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.2.2)
#>  evaluate       0.20       2023-01-17 [1] CRAN (R 4.2.2)
#>  fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.2)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.2.2)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.2)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.2.2)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.2.2)
#>  future         1.32.0     2023-03-07 [1] CRAN (R 4.2.3)
#>  future.apply   1.10.0     2022-11-05 [1] CRAN (R 4.2.2)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.2)
#>  ggplot2      * 3.4.1      2023-02-10 [1] CRAN (R 4.2.3)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.2.2)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.2)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.2.2)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.2)
#>  gtable         0.3.1      2022-09-01 [1] CRAN (R 4.2.2)
#>  hardhat        1.2.0      2022-06-30 [1] CRAN (R 4.2.2)
#>  htmltools      0.5.4      2022-12-07 [1] CRAN (R 4.2.2)
#>  infer        * 1.0.4      2022-12-02 [1] CRAN (R 4.2.2)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.2.3)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.2)
#>  kernlab      * 0.9-31     2022-06-09 [1] CRAN (R 4.2.0)
#>  knitr          1.42       2023-01-25 [1] CRAN (R 4.2.2)
#>  lattice        0.20-45    2021-09-22 [1] CRAN (R 4.2.2)
#>  lava           1.7.2.1    2023-02-27 [1] CRAN (R 4.2.3)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.2.3)
#>  lifecycle      1.0.3      2022-10-07 [1] CRAN (R 4.2.2)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.2.2)
#>  lubridate      1.9.2      2023-02-10 [1] CRAN (R 4.2.3)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.1.3)
#>  MASS           7.3-58.1   2022-08-03 [1] CRAN (R 4.2.2)
#>  Matrix         1.5-3      2022-11-11 [1] CRAN (R 4.2.2)
#>  mda          * 0.5-3      2022-05-05 [1] CRAN (R 4.2.2)
#>  modeldata    * 1.0.1      2022-09-06 [1] CRAN (R 4.2.2)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.2)
#>  nnet           7.3-18     2022-09-28 [1] CRAN (R 4.2.2)
#>  parallelly     1.34.0     2023-01-13 [1] CRAN (R 4.2.2)
#>  parsnip      * 1.0.4.9004 2023-03-23 [1] Github (tidymodels/parsnip@34f74fe)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.2.2)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.2)
#>  prettyunits    1.1.1      2020-01-24 [1] CRAN (R 4.2.2)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.2.2)
#>  purrr        * 1.0.1.9000 2023-03-12 [1] Github (tidyverse/purrr@fd5a732)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.2.2)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo           1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils        2.12.2     2022-11-11 [1] CRAN (R 4.2.2)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.2)
#>  Rcpp           1.0.9      2022-07-08 [1] CRAN (R 4.2.2)
#>  recipes      * 1.0.5      2023-02-20 [1] CRAN (R 4.2.3)
#>  reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.3)
#>  rlang          1.0.6      2022-09-24 [1] CRAN (R 4.2.2)
#>  rmarkdown      2.19       2022-12-15 [1] CRAN (R 4.2.2)
#>  ROSE           0.0-4      2021-06-14 [1] CRAN (R 4.2.2)
#>  rpart          4.1.19     2022-10-21 [1] CRAN (R 4.2.2)
#>  rsample      * 1.1.1.9000 2023-03-23 [1] Github (tidymodels/rsample@43b3e2d)
#>  rstudioapi     0.14       2022-08-22 [1] CRAN (R 4.2.2)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.2.2)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.2)
#>  styler         1.8.1      2022-11-07 [1] CRAN (R 4.2.2)
#>  survival       3.4-0      2022-08-09 [1] CRAN (R 4.2.2)
#>  themis       * 1.0.0      2022-07-02 [1] CRAN (R 4.2.2)
#>  tibble       * 3.2.0      2023-03-08 [1] CRAN (R 4.2.2)
#>  tidymodels   * 1.0.0.9000 2022-12-02 [1] Github (tidymodels/tidymodels@9c71536)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.2.2)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.1.3)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.3)
#>  timeDate       4022.108   2023-01-07 [1] CRAN (R 4.2.2)
#>  tune         * 1.0.1.9003 2023-03-23 [1] Github (tidymodels/tune@21ae429)
#>  utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.2)
#>  vctrs          0.5.2      2023-01-23 [1] CRAN (R 4.2.2)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.2)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.2.3)
#>  workflowsets * 1.0.0      2022-07-12 [1] CRAN (R 4.2.2)
#>  xfun           0.37       2023-01-31 [1] CRAN (R 4.2.2)
#>  yaml           2.3.7      2023-01-23 [1] CRAN (R 4.2.2)
#>  yardstick    * 1.1.0      2022-09-07 [1] CRAN (R 4.2.2)
#> 
#>  [1] D:/Program Files/R/R-4.2.2/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Here's what you can do inside the for loop:

library(tidymodels)

# Session setup: resolve function-name conflicts in favour of tidymodels,
# use the black-and-white ggplot2 theme, and adjust tibble printing options.
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)

# Example data: `cells`, with the `case` column removed.
data(cells)
cells <- dplyr::select(cells, -case)

# Resampling object used for tuning (seeded so the split is reproducible).
set.seed(1)
val_set <- validation_split(cells)

# Preprocessors ---------------------------------------------------------------

# Base recipe: Yeo-Johnson transform followed by normalisation of all predictors.
basic_recipe <- recipe(class ~ ., data = cells)
basic_recipe <- step_YeoJohnson(basic_recipe, all_predictors())
basic_recipe <- step_normalize(basic_recipe, all_predictors())

# Base recipe plus PCA; the number of components is marked for tuning.
pca_recipe <- step_pca(basic_recipe, all_predictors(), num_comp = tune())

# Base recipe plus the spatial sign transformation.
ss_recipe <- step_spatialsign(basic_recipe, all_predictors())

# Models ----------------------------------------------------------------------

# K-nearest neighbours ("kknn" engine); both arguments are marked for tuning.
knn_mod <- nearest_neighbor(neighbors = tune(), weight_func = tune())
knn_mod <- set_engine(knn_mod, "kknn")
knn_mod <- set_mode(knn_mod, "classification")

# Logistic regression ("glm" engine); nothing to tune.
lr_mod <- set_engine(logistic_reg(), "glm")

# Workflow set ----------------------------------------------------------------

# The list names become the two halves of each workflow ID (e.g. "pca_knn").
preproc <- list(
  none = basic_recipe,
  pca = pca_recipe,
  sp_sign = ss_recipe
)
models <- list(
  knn = knn_mod,
  logistic = lr_mod
)

# cross = TRUE pairs every preprocessor with every model.
cell_set <- workflow_set(preproc = preproc, models = models, cross = TRUE)

# save_workflow = TRUE is set on the control object passed to workflow_map();
# fit_best() is applied to each tuning result afterwards.
ctrl <- control_grid(save_workflow = TRUE)

# Tune / resample every workflow in the set. workflow_map() is called without
# `fn`, so its default tuning function is used.
set.seed(1)
cell_tuned <- workflow_map(
  cell_set,
  resamples = val_set,
  grid = 5,
  control = ctrl
)

# A workflow set is a tibble, so fit_best() can be mapped over the `result`
# list-column, one tuning result at a time; the fitted workflows are stored
# in a new `fit` list-column next to their `wflow_id`.
cell_res <- mutate(
  cell_tuned,
  fit = map(result, function(res) fit_best(res, metric = "roc_auc"))
)
cell_res
#> # A workflow set/tibble: 6 × 5
#>   wflow_id         info             option    result    fit       
#>   <chr>            <list>           <list>    <list>    <list>    
#> 1 none_knn         <tibble [1 × 4]> <opts[3]> <tune[+]> <workflow>
#> 2 none_logistic    <tibble [1 × 4]> <opts[3]> <rsmp[+]> <workflow>
#> 3 pca_knn          <tibble [1 × 4]> <opts[3]> <tune[+]> <workflow>
#> 4 pca_logistic     <tibble [1 × 4]> <opts[3]> <tune[+]> <workflow>
#> 5 sp_sign_knn      <tibble [1 × 4]> <opts[3]> <tune[+]> <workflow>
#> 6 sp_sign_logistic <tibble [1 × 4]> <opts[3]> <rsmp[+]> <workflow>
cell_res$fit[[1]]
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: nearest_neighbor()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 2 Recipe Steps
#> 
#> • step_YeoJohnson()
#> • step_normalize()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> 
#> Call:
#> kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(11L,     data, 5), kernel = ~"rank")
#> 
#> Type of response variable: nominal
#> Minimal misclassification: 0.1921743
#> Best kernel: rank
#> Best k: 11

Created on 2023-03-23 with reprex v2.0.2

Perfect, that solved it. I'm grateful for your prompt help.

This topic was automatically closed 7 days after the last reply. New replies are no longer allowed.

If you have a query related to it or one of the replies, start a new topic and refer back with a link.