tidymodels: fastest way to extract a model object from fit_resamples() results

What would be the most efficient way to extract my parsnip model object from fitted resamples (tune::fit_resamples())?

When I want to train a model with cross-validation, I can go with either tune::tune_grid() or fit_resamples().

Let's say I know the best parameters for my algorithm, so I don't need any parameter tuning, which means I decide to go with fit_resamples().
If I had gone with tune_grid(), I would usually set up a workflow, since I evaluate different models after tune_grid() has run: I use tune::show_best() and tune::select_best() to explore and extract the best parameters for my model. Then I use tune::finalize_workflow() and workflows::pull_workflow_fit() to extract my model object. Further, when I want to see predictions, I use tune::last_fit() and tune::collect_predictions().
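The tune_grid() route described above can be sketched as follows. This is a minimal sketch, not the asker's actual code: it assumes the hpc_data example used in the answer below, tunes only `cost`, holds out a test set with initial_split() so that last_fit() has data to evaluate on, and picks the best candidate by RMSE. The object names (hpc_split, svm_tune_wf, best_params, final_res) are illustrative.

```r
library(tidymodels)

data("hpc_data")

# Assumption: hold out a test set so that last_fit() has something to evaluate on
set.seed(123)
hpc_split <- initial_split(hpc_data)
hpc_train <- training(hpc_split)
hpc_folds <- vfold_cv(hpc_train, v = 5)

# Mark cost for tuning; degree stays fixed at 1
svm_tune_spec <- svm_poly(degree = 1, cost = tune()) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

svm_tune_wf <- workflow() %>%
  add_model(svm_tune_spec) %>%
  add_formula(compounds ~ .)

# Evaluate three candidate cost values across the folds
svm_grid <- tune_grid(svm_tune_wf, resamples = hpc_folds, grid = 3)

# Explore the candidates, then keep the one with the lowest RMSE
show_best(svm_grid, metric = "rmse")
best_params <- select_best(svm_grid, metric = "rmse")

# Plug the chosen cost into the workflow, fit it on the training split
# and evaluate it once on the test split
final_wf <- finalize_workflow(svm_tune_wf, best_params)
final_res <- last_fit(final_wf, hpc_split)

# Test-set predictions from the final fit
collect_predictions(final_res)
```

With fixed parameters, the select_best() and finalize_workflow() steps are exactly the ones that have nothing left to do.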

All these steps seem redundant when I go with fit_resamples(), since I basically have only one model with fixed parameters. So the steps above seem unnecessary, yet it looks like I still have to go through them. Do I?

After fit_resamples() has run, I get a tibble with information about splits, .metrics, .notes, etc.
So my question really comes down to: what is the fastest way from that tibble to my final model object?

The important thing to realize about fit_resamples() is that its purpose is to measure performance. The models that you train in fit_resamples() are not kept or used later.

Let's imagine that you know the parameters you want to use for an SVM model.

library(tidymodels)
#> ── Attaching packages ─────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0      ✓ recipes   0.1.13
#> ✓ dials     0.0.8      ✓ rsample   0.0.7 
#> ✓ dplyr     1.0.0      ✓ tibble    3.0.3 
#> ✓ ggplot2   3.3.2      ✓ tidyr     1.1.0 
#> ✓ infer     0.5.3      ✓ tune      0.1.1 
#> ✓ modeldata 0.0.2      ✓ workflows 0.1.2 
#> ✓ parsnip   0.1.2      ✓ yardstick 0.0.7 
#> ✓ purrr     0.3.4
#> ── Conflicts ────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

## pretend this is your training data
data("hpc_data")

svm_spec <- svm_poly(degree = 1, cost = 1/4) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

svm_wf <- workflow() %>%
  add_model(svm_spec) %>%
  add_formula(compounds ~ .)

hpc_folds <- vfold_cv(hpc_data)

svm_rs <- svm_wf %>%
  fit_resamples(
    resamples = hpc_folds
  )

svm_rs
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 x 4
#>    splits             id     .metrics         .notes          
#>    <list>             <chr>  <list>           <list>          
#>  1 <split [3.9K/434]> Fold01 <tibble [2 × 3]> <tibble [0 × 1]>
#>  2 <split [3.9K/433]> Fold02 <tibble [2 × 3]> <tibble [0 × 1]>
#>  3 <split [3.9K/433]> Fold03 <tibble [2 × 3]> <tibble [0 × 1]>
#>  4 <split [3.9K/433]> Fold04 <tibble [2 × 3]> <tibble [0 × 1]>
#>  5 <split [3.9K/433]> Fold05 <tibble [2 × 3]> <tibble [0 × 1]>
#>  6 <split [3.9K/433]> Fold06 <tibble [2 × 3]> <tibble [0 × 1]>
#>  7 <split [3.9K/433]> Fold07 <tibble [2 × 3]> <tibble [0 × 1]>
#>  8 <split [3.9K/433]> Fold08 <tibble [2 × 3]> <tibble [0 × 1]>
#>  9 <split [3.9K/433]> Fold09 <tibble [2 × 3]> <tibble [0 × 1]>
#> 10 <split [3.9K/433]> Fold10 <tibble [2 × 3]> <tibble [0 × 1]>

There are no fitted models in this output. Models were fitted to each of these resamples, but you don't want to use them for anything; they are thrown away because their only purpose is computing the .metrics used to estimate performance.
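To see what fit_resamples() does keep, the sketch below rebuilds the same workflow and resamples and then summarises the per-fold .metrics with collect_metrics(). It repeats the setup so it runs on its own; the seed is an arbitrary assumption added for reproducibility, so the fold assignments will differ from the reprex above.

```r
library(tidymodels)

data("hpc_data")

svm_wf <- workflow() %>%
  add_model(
    svm_poly(degree = 1, cost = 1/4) %>%
      set_engine("kernlab") %>%
      set_mode("regression")
  ) %>%
  add_formula(compounds ~ .)

set.seed(123)
hpc_folds <- vfold_cv(hpc_data, v = 10)

svm_rs <- fit_resamples(svm_wf, resamples = hpc_folds)

# Only splits, id, .metrics and .notes are stored; there is no column of fitted models
names(svm_rs)

# Average the per-fold .metrics (rmse and rsq by default for a regression model)
collect_metrics(svm_rs)
```

collect_metrics() averages each metric over the ten folds. Those averages are the performance estimate; nothing in svm_rs can be used to predict on new data.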

If you want a model that you can use to predict on new data, you need to go back to your whole training set and fit your model once more on all of it.

svm_fit <- svm_wf %>%
  fit(hpc_data)

svm_fit
#> ══ Workflow [trained] ═══════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: svm_poly()
#> 
#> ── Preprocessor ─────────────────────────────────────────────────────
#> compounds ~ .
#> 
#> ── Model ────────────────────────────────────────────────────────────
#> Support Vector Machine object of class "ksvm" 
#> 
#> SV type: eps-svr  (regression) 
#>  parameter : epsilon = 0.1  cost C = 0.25 
#> 
#> Polynomial kernel function. 
#>  Hyperparameters : degree =  1  scale =  1  offset =  1 
#> 
#> Number of Support Vectors : 2827 
#> 
#> Objective Function Value : -284.7255 
#> Training error : 0.835421

Created on 2020-07-17 by the reprex package (v0.3.0)

This final object is one that you can use with pull_workflow_fit() for variable importance or similar.
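A minimal sketch of what you can do with that trained workflow, assuming a newer workflows release in which extract_fit_parsnip() and extract_fit_engine() are available. In those releases pull_workflow_fit() is deprecated in favour of extract_fit_parsnip(); on older versions like the one in the reprex, use pull_workflow_fit() instead. The block refits the same workflow so it runs on its own, and uses the first rows of hpc_data as a stand-in for genuinely new data.

```r
library(tidymodels)

data("hpc_data")

# Fit once on the whole training set, as in the answer
svm_fit <- workflow() %>%
  add_model(
    svm_poly(degree = 1, cost = 1/4) %>%
      set_engine("kernlab") %>%
      set_mode("regression")
  ) %>%
  add_formula(compounds ~ .) %>%
  fit(hpc_data)

# parsnip model_fit object (older workflows versions call this helper pull_workflow_fit())
svm_parsnip <- extract_fit_parsnip(svm_fit)

# Underlying kernlab "ksvm" object, for engine-specific tools
svm_ksvm <- extract_fit_engine(svm_fit)

# The trained workflow applies the formula preprocessing itself when predicting;
# head(hpc_data) is only a stand-in for new data
predict(svm_fit, new_data = head(hpc_data))
```

Predicting through the workflow rather than the extracted model keeps the preprocessing and the model together, so new data only needs the same raw columns as the training set.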
