How do I extract model information, such as the model engine, from a tuned tidymodels object?

I have a fitted model object and a tuned model object. From the fitted model, I can extract model information such as the spec, engine, and class. See the example below, which uses a fitted XGBoost model called fit_xgboost: with commands like class() and fit_xgboost$trained, I can see the model's class, spec, and computational engine.


> class(fit_xgboost)
[1] "workflow"

> fit_xgboost$trained
[1] TRUE

> fit_xgboost$fit$actions$model$spec
Boosted Tree Model Specification (regression)
Computational engine: xgboost 

> fit_xgboost$fit$fit$spec$engine
[1] "xgboost"


How do I get the same information when the model object has been tuned? The class of a tuned model is "tune_results" "tbl_df" "tbl" "data.frame", but is there a way to get the spec, computational engine, etc.? For context, I'm writing some custom functions for tuned model objects and need a way to check the model information and give the user feedback based on certain conditions. Maybe this isn't the right approach, so I'm happy to take feedback or better suggestions.

The code below produces similar model objects:

# Libraries
library(tidyverse)
library(tidymodels)


# Data
mtcars_data <- mtcars %>% as_tibble()

# Resamples
set.seed(123)
mtcars_resamples <- vfold_cv(mtcars_data, v = 5)

# Recipe
recipe_spec <- recipe(mpg ~., data = mtcars_data)

# Model Spec
fit_xgboost <- workflow() %>% 
    add_model(
        spec = boost_tree() %>% 
            set_mode("regression") %>% 
            set_engine("xgboost")
    ) %>% 
    add_recipe(recipe_spec) %>% 
    fit(mtcars_data)

# Tuning
xgboost_spec <- boost_tree(
    trees          = 500,
    min_n          = tune(),
    tree_depth     = tune(),
    learn_rate     = tune(),
    loss_reduction = tune(),
    sample_size    = tune()
) %>%
    set_mode("regression") %>%
    set_engine("xgboost")

tune_xgboost <- tune_grid(
    object    = workflow() %>% add_recipe(recipe_spec) %>% add_model(xgboost_spec),
    resamples = mtcars_resamples,
    grid      = grid_latin_hypercube(x = parameters(xgboost_spec), size = 5),
    control   = control_grid(save_pred = TRUE, verbose = FALSE),
    metrics   = metric_set(mae, rmse, rsq)
)


Thanks for the post! These are good questions.

The extract_*() functions and the extract argument to control_grid() can take you a long way. They work on a variety of tidymodels objects, so you don't have to subset with, e.g., $fit$fit to pull out elements of objects.
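As a quick illustration on the fitted workflow from the question, the accessors below return the same information as the $-subsetting shown there. This is a minimal sketch assuming the tidymodels and xgboost packages are installed; it refits a small version of fit_xgboost on mtcars so it runs on its own.

```r
library(tidymodels)

# Refit a small version of fit_xgboost from the question
fit_xgboost <- workflow() %>%
  add_model(boost_tree() %>% set_mode("regression") %>% set_engine("xgboost")) %>%
  add_recipe(recipe(mpg ~ ., data = mtcars)) %>%
  fit(data = mtcars)

# Same object as fit_xgboost$fit$actions$model$spec
spec <- extract_spec_parsnip(fit_xgboost)
spec$engine  # "xgboost"
spec$mode    # "regression"
class(spec)  # "boost_tree" "model_spec"

# The raw engine fit, without needing to know how deeply it is nested
booster <- extract_fit_engine(fit_xgboost)
class(booster)
```

Because these functions dispatch on the object type, the same calls keep working if the internal structure of a workflow changes, which $fit$fit-style subsetting does not guarantee.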

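Using a slightly modified version of your code (the data are replicated 100 times, the workflow is fit on mtcars, and parameters() is replaced with extract_parameter_set_dials()):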

# Libraries
library(tidyverse)
library(tidymodels)

# Data
mtcars_data <- mtcars[rep(1:32, 100), ] %>% as_tibble()

# Resamples
set.seed(123)
mtcars_resamples <- vfold_cv(mtcars_data, v = 5)

# Recipe
recipe_spec <- recipe(mpg ~., data = mtcars_data)

# Model Spec
fit_xgboost <- workflow() %>% 
  add_model(
    spec = boost_tree() %>% 
      set_mode("regression") %>% 
      set_engine("xgboost")
  ) %>% 
  add_recipe(recipe_spec) %>% 
  fit(mtcars)

# Tuning
xgboost_spec <- boost_tree(
  trees          = 500,
  min_n          = tune(),
  tree_depth     = tune(),
  learn_rate     = tune(),
  loss_reduction = tune(),
  sample_size    = tune()
) %>%
  set_mode("regression") %>%
  set_engine("xgboost") 

Starting from your code, I've additionally passed save_workflow = TRUE and extract = identity to control_grid().

tune_xgboost <- tune_grid(
  object    = workflow() %>% add_recipe(recipe_spec) %>% add_model(xgboost_spec),
  resamples = mtcars_resamples,
  grid      = grid_latin_hypercube(x = extract_parameter_set_dials(xgboost_spec), size = 5),
  control   = control_grid(save_pred = TRUE, verbose = FALSE, save_workflow = TRUE, extract = identity),
  metrics   = metric_set(mae, rmse, rsq)
)

tune_xgboost
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 6
#>   splits             id    .metrics          .notes   .extracts .predictions
#>   <list>             <chr> <list>            <list>   <list>    <list>      
#> 1 <split [2560/640]> Fold1 <tibble [15 × 9]> <tibble> <tibble>  <tibble>    
#> 2 <split [2560/640]> Fold2 <tibble [15 × 9]> <tibble> <tibble>  <tibble>    
#> 3 <split [2560/640]> Fold3 <tibble [15 × 9]> <tibble> <tibble>  <tibble>    
#> 4 <split [2560/640]> Fold4 <tibble [15 × 9]> <tibble> <tibble>  <tibble>    
#> 5 <split [2560/640]> Fold5 <tibble [15 × 9]> <tibble> <tibble>  <tibble>

save_workflow = TRUE stores the workflow passed as object in the resulting tune_results object, and extract_workflow() retrieves it:

wf <- extract_workflow(tune_xgboost)

wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 0 Recipe Steps
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   trees = 500
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#>   loss_reduction = tune()
#>   sample_size = tune()
#> 
#> Computational engine: xgboost

extract_spec_parsnip(wf)
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   trees = 500
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#>   loss_reduction = tune()
#>   sample_size = tune()
#> 
#> Computational engine: xgboost
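For the custom functions mentioned in the question, one option is to wrap these two calls in a check that gives the user a clear message when the workflow was not saved. The sketch below is hypothetical: describe_tuned_model() is not part of tidymodels, and it assumes tuning was run with save_workflow = TRUE. It uses a small tuning run on mtcars so it is self-contained.

```r
library(tidymodels)

# Hypothetical helper (not part of tidymodels): summarise the model behind a
# tune_results object so a custom function can check it before proceeding.
describe_tuned_model <- function(x) {
  if (!inherits(x, "tune_results")) {
    stop("`x` must be a tune_results object.", call. = FALSE)
  }
  # extract_workflow() needs save_workflow = TRUE at tuning time; turn any
  # failure into NULL so we can give our own, more specific message.
  wf <- tryCatch(extract_workflow(x), error = function(e) NULL)
  if (is.null(wf)) {
    stop("No workflow stored in `x`; re-run tuning with ",
         "control_grid(save_workflow = TRUE).", call. = FALSE)
  }
  spec <- extract_spec_parsnip(wf)
  list(model = class(spec)[1], mode = spec$mode, engine = spec$engine)
}

# Small tuning run to try the helper on
set.seed(123)
folds <- vfold_cv(mtcars, v = 3)
tuned <- workflow() %>%
  add_recipe(recipe(mpg ~ ., data = mtcars)) %>%
  add_model(boost_tree(learn_rate = tune()) %>%
              set_mode("regression") %>%
              set_engine("xgboost")) %>%
  tune_grid(resamples = folds,
            grid = tibble(learn_rate = c(0.1, 0.3)),
            control = control_grid(save_workflow = TRUE))

info <- describe_tuned_model(tuned)
# Example of condition-based feedback for the user
if (info$engine != "xgboost") {
  message("This check was written with xgboost models in mind.")
}
```

Returning a plain list keeps the helper easy to test; the calling function decides which conditions deserve a message, warning, or error.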

The extract argument takes a function that extracts any portion of the workflow trained on each resample and retains it in the resulting object. extract = identity keeps the whole fitted workflow, but you could also pass extract_spec_parsnip, extract_fit_engine, etc. collect_extracts() gathers the extracted objects, one row per resample and candidate:

collect_extracts(tune_xgboost)
#> # A tibble: 25 × 8
#>    id    min_n tree_depth learn_rate loss_reduction sample_size .extracts 
#>    <chr> <int>      <int>      <dbl>          <dbl>       <dbl> <list>    
#>  1 Fold1    15          6    0.00175  0.00000851          0.720 <workflow>
#>  2 Fold1     2         12    0.0949   0.00000000107       0.202 <workflow>
#>  3 Fold1    35          3    0.306    0.00239             0.636 <workflow>
#>  4 Fold1    23         13    0.00320  0.975               0.920 <workflow>
#>  5 Fold1    29          9    0.0312   0.00000117          0.426 <workflow>
#>  6 Fold2    15          6    0.00175  0.00000851          0.720 <workflow>
#>  7 Fold2     2         12    0.0949   0.00000000107       0.202 <workflow>
#>  8 Fold2    35          3    0.306    0.00239             0.636 <workflow>
#>  9 Fold2    23         13    0.00320  0.975               0.920 <workflow>
#> 10 Fold2    29          9    0.0312   0.00000117          0.426 <workflow>
#> # ℹ 15 more rows
#> # ℹ 1 more variable: .config <chr>
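Since extract accepts any function of the fitted workflow, you can also keep just the piece you need instead of the whole workflow. A minimal sketch, assuming tidymodels and xgboost are installed: the anonymous extract function below is illustrative and keeps only the engine name from each resample's fit.

```r
library(tidymodels)

set.seed(123)
folds <- vfold_cv(mtcars, v = 3)

# Each call to the extract function receives the workflow fitted on one
# resample; here we keep only the engine name rather than the whole workflow.
res <- workflow() %>%
  add_recipe(recipe(mpg ~ ., data = mtcars)) %>%
  add_model(boost_tree(learn_rate = tune()) %>%
              set_mode("regression") %>%
              set_engine("xgboost")) %>%
  tune_grid(resamples = folds,
            grid = tibble(learn_rate = c(0.1, 0.3)),
            control = control_grid(
              extract = function(x) extract_spec_parsnip(x)$engine
            ))

engines <- collect_extracts(res)
unique(unlist(engines$.extracts))  # "xgboost"
```

Keeping small extracts like this avoids storing a full fitted model per resample and candidate, which can make tuning results much larger than they need to be.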

Hope this is helpful for you!

Created on 2023-04-28 with reprex v2.0.2
