How can I get feature weights from tidymodels with resamples, tuning, etc.? (how do I use extract_fit_engine()?)

Any example code beyond what's in the documentation would be great.

It seems that extract_fit_engine() is the way to go. Below is a reprex of a simple version, in which I can get the coefficients from glmnet through tidymodels. What I'm stuck on is getting the column coefficients when resampling, tuning, etc.: I haven't been able to set up a correct workflow() to use with pull_workflow_fit() or extract_fit_engine() (not shown below).
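For reference, here is a minimal sketch of what `extract_fit_engine()` returns outside of tuning. It assumes current versions of parsnip and workflows (the replies below trace the `workflow` error to outdated packages) and uses `mtcars` with an arbitrary penalty of 0.1 purely for illustration. The function pulls the underlying glmnet object out of either a fitted parsnip model or a fitted workflow, after which glmnet's own methods such as `coef()` apply.

```r
library(tidymodels)
library(glmnet)

# Elastic net spec with an illustrative fixed penalty and mixture
spec <- linear_reg(penalty = 0.1, mixture = 0.6) %>%
  set_engine("glmnet")

# 1. Fitted parsnip model -> underlying glmnet object
parsnip_fit <- fit(spec, mpg ~ ., data = mtcars)
engine_from_parsnip <- extract_fit_engine(parsnip_fit)

# 2. Fitted workflow -> the same kind of glmnet object
wflow_fit <- workflow() %>%
  add_formula(mpg ~ .) %>%
  add_model(spec) %>%
  fit(data = mtcars)
engine_from_wflow <- extract_fit_engine(wflow_fit)

# glmnet's coef() method works on the extracted object;
# at s = 0.1 it returns one sparse column: intercept + 10 predictors
coef(engine_from_wflow, s = 0.1)

# tidy() on the parsnip fit gives the coefficient table at penalty = 0.1
tidy(parsnip_fit)
```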

# for reprex
library(reprex)
library(plyr)
library(janitor)
#> 
#> Attaching package: 'janitor'
#> The following objects are masked from 'package:stats':
#> 
#>     chisq.test, fisher.test

library(lubridate)
#> 
#> Attaching package: 'lubridate'
#> The following objects are masked from 'package:base':
#> 
#>     date, intersect, setdiff, union
library(stringr)

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(glmnet)
#> Loading required package: Matrix
#> 
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#> 
#>     expand, pack, unpack
#> Loaded glmnet 4.1-1

Sys.setenv(TZ = 'GMT')
options(scipen = 99999)
#rm(list=ls())

# Synthesise time series data
df = data.frame(yearr = sample(2015:2021, 2000, replace = TRUE),
                monthh = sample(1:12, 2000, replace = TRUE),
                dayy = sample(1:29, 2000, replace = TRUE)) |>
  mutate(datee = ymd(paste(yearr, monthh, dayy)),
         weekk = week(datee),
         quarterr = quarter(datee),
         semesterr = semester(datee),
         doyy = yday(datee),
         yy = sample(0:100, 2000, replace = TRUE) + (130 * yearr) + (2 * monthh) + (2 * weekk),
         dummyy = round(sample(0:1, 2000, replace = TRUE))) |>
  filter(!is.na(datee)) |>
  arrange(datee) |>
  mutate(ii = row_number())
#> Warning: 3 failed to parse.

####################
# SIMPLE VERSION
####################

# Set up necessary elastic net objects
spec_enet = linear_reg(mode = 'regression', penalty = 0, mixture = 0.6) |>
  set_engine('glmnet')

rec_generic = df |>
  recipe(yy ~ .) |>
  step_rm(datee) |> # For analysis
  prep()

time_only = do.call('c', df |> select(datee)) # Save date variable from step_rm

df_baked = rec_generic |>
  bake(NULL)

# Easy way to get coefficients
fit_enet = spec_enet |>
  fit(yy ~ ., data = df_baked)

# Column coefficients
tidy_enet = fit_enet |>
  tidy()

tidy_enet # coefficients table AKA 'feature weights' or 'betas'
#> # A tibble: 10 x 3
#>    term            estimate penalty
#>    <chr>              <dbl>   <dbl>
#>  1 (Intercept) 153877.            0
#>  2 yearr           53.7           0
#>  3 monthh           1.86          0
#>  4 dayy             0             0
#>  5 weekk            0.406         0
#>  6 quarterr         1.27          0
#>  7 semesterr        0             0
#>  8 doyy             0.00750       0
#>  9 dummyy           0             0
#> 10 ii               0.263         0

###################################
# WITH RESAMPLES, GRID SEARCH, ETC.
###################################
# For tune_grid parameters
spec_enet_tune = linear_reg(mode = 'regression', penalty = tune(), mixture = tune()) |>
  set_engine('glmnet')

# rec_iteration is without prep()
rec_iteration = df_baked |>
  recipe(yy ~ .) |>
  step_zv(doyy)

folds = df |>
  mc_cv(prop = 3/4, times = 10)

metric = metric_set(rmse)

grid_pen_mix = expand_grid(penalty = seq(0, 100, by = 25),
                           mixture = seq(0, 1, by = 0.25))

# Goes under ctrl
glmnet_vars = function(x) {
  # `x` will be a workflow object
  mod <- extract_fit_engine(x) #library(hardhat) # https://tune.tidymodels.org/reference/extract-tune.html
  # `df` is the number of model terms for each penalty value
  tibble(penalty = mod$lambda, num_vars = mod$df)
}

ctrl <- control_grid(extract = glmnet_vars, verbose = TRUE)

tune_attempt2 = tune_grid(spec_enet_tune,
                          rec_iteration,
                          resamples = folds, # way above
                          grid = grid_pen_mix,
                          metrics = metric,
                          control = ctrl) # Where do I put the workflow?
#> i Resample01: preprocessor 1/1
#> v Resample01: preprocessor 1/1
...
#> v Resample10: preprocessor 1/1, model 5/5
#> i Resample10: preprocessor 1/1, model 5/5 (predictions)

tune_attempt2$.extracts[[1]] # extracts for the first resample
#> # A tibble: 25 x 4
#>    penalty mixture .extracts      .config              
#>      <dbl>   <dbl> <list>         <chr>                
#>  1     100    0    <try-errr [1]> Preprocessor1_Model01
#>  2     100    0    <try-errr [1]> Preprocessor1_Model02
#>  3     100    0    <try-errr [1]> Preprocessor1_Model03
#>  4     100    0    <try-errr [1]> Preprocessor1_Model04
#>  5     100    0    <try-errr [1]> Preprocessor1_Model05
#>  6     100    0.25 <try-errr [1]> Preprocessor1_Model06
#>  7     100    0.25 <try-errr [1]> Preprocessor1_Model07
#>  8     100    0.25 <try-errr [1]> Preprocessor1_Model08
#>  9     100    0.25 <try-errr [1]> Preprocessor1_Model09
#> 10     100    0.25 <try-errr [1]> Preprocessor1_Model10
#> # ... with 15 more rows
# Error in UseMethod("extract_fit_engine"): no applicable method for 'extract_fit_engine' applied to an object of class "workflow"

# How do I get glmnet/linear_reg object from tune_grid?

Created on 2021-09-07 by the reprex package (v2.0.0)


I've looked through related posts and answers, and through the documentation numerous times, but I've been struggling with this for several weeks.

The answer to this is fairly complicated, partly because of tidymodels but mostly because of the nature of the glmnet model.
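To see what "the nature of the glmnet model" means here, a minimal sketch calling glmnet directly on `mtcars` (an illustrative dataset, not the one in this thread): a single fit with `alpha = 0.5` (tidymodels' `mixture`) and ten `lambda` values (tidymodels' `penalty`) estimates coefficients for every penalty at once. That is why, within a resample, all grid rows that share a mixture value extract the same set of coefficients.

```r
library(glmnet)

# Predictor matrix and outcome (mtcars used only for illustration)
x <- as.matrix(mtcars[, setdiff(names(mtcars), "mpg")])
y <- mtcars$mpg

pen_vals <- 10^seq(-3, 0, length.out = 10)

# One fit: mixture (alpha) = 0.5 over ten penalty (lambda) values
glmnet_path_fit <- glmnet(x, y, alpha = 0.5, lambda = pen_vals)

# Rows = intercept + 10 predictors, columns = one per penalty value
path_coefs <- coef(glmnet_path_fit)
dim(path_coefs)

# glmnet stores the supplied lambda values in decreasing order
glmnet_path_fit$lambda
```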

The code is below; I will also write an article for tidymodels.org (and maybe add a convenience function to the tune package).

The important point is to use the special path_values argument that is described in the technical documentation about glmnet and tidymodels.
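A minimal sketch of what `path_values` does, again using `mtcars` as an assumed stand-in dataset. The engine argument is handed to glmnet as its `lambda` sequence, so the extracted glmnet object covers exactly those penalties rather than glmnet's default, data-dependent sequence. In `tune_grid()`, this keeps every resample's coefficient path on the same penalty values as the grid.

```r
library(tidymodels)
library(glmnet)

pen_vals <- 10^seq(-3, 0, length.out = 10)

# path_values becomes glmnet's lambda sequence
spec <- linear_reg(penalty = 0.1, mixture = 0.5) %>%
  set_engine("glmnet", path_values = pen_vals)

engine_fit <- fit(spec, mpg ~ ., data = mtcars) %>%
  extract_fit_engine()

# The fitted path is pen_vals, sorted in decreasing order
engine_fit$lambda

# Without path_values, glmnet picks its own sequence (up to 100 values)
default_fit <- linear_reg(penalty = 0.1, mixture = 0.5) %>%
  set_engine("glmnet") %>%
  fit(mpg ~ ., data = mtcars) %>%
  extract_fit_engine()
length(default_fit$lambda)
```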

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
tidymodels_prefer()

data(Chicago)

Chicago <- Chicago %>% select(ridership, Clark_Lake, Austin, Harlem)

set.seed(1)
bt <- bootstraps(Chicago, times = 5)
rec <- 
  recipe(ridership ~ ., data = Chicago) %>% 
  step_normalize(all_numeric_predictors())

pen_vals <- 10^seq(-3, 0, length.out = 10)
grid <- crossing(penalty = pen_vals, mixture = c(0.5, 1.0))

glmnet_spec <- 
  linear_reg(penalty = tune(), mixture = tune()) %>% 
  set_engine("glmnet", path_values = pen_vals)

glmnet_wflow <- 
  workflow() %>% 
  add_model(glmnet_spec) %>% 
  add_recipe(rec)
get_glmnet_coefs <- function(x) {
  x %>% 
    extract_fit_engine() %>% 
    tidy()
}
parsnip_ctrl <- control_grid(extract = get_glmnet_coefs)

set.seed(2)
glmnet_res <- 
  glmnet_wflow %>% 
  tune_grid(
    resamples = bt,
    grid = grid,
    control = parsnip_ctrl
  )
glmnet_coefs <- 
  glmnet_res %>% 
  select(id, .extracts) %>% 
  unnest(.extracts) %>% 
  # The `penalty` column at this level is redundant, so 
  # we'll remove it. 
  select(id, mixture, .extracts) %>% 
  # Since you get all of the coefficients for each glmnet
  # fit, the values are replicated within a value of mixture.
  # We'll keep the first row so that we don't get the same
  # values over and over again. 
  group_by(id, mixture) %>% 
  slice(1) %>% 
  ungroup() %>% 
  unnest(.extracts) %>% 
  # Rename to be consistent with tidymodels
  rename(penalty = lambda)

glmnet_coefs %>% 
  select(id, penalty, mixture, term, estimate) %>% 
  filter(term == "Clark_Lake")
#> # A tibble: 100 × 5
#>    id         penalty mixture term       estimate
#>    <chr>        <dbl>   <dbl> <chr>         <dbl>
#>  1 Bootstrap1 1           0.5 Clark_Lake     3.05
#>  2 Bootstrap1 0.464       0.5 Clark_Lake     3.95
#>  3 Bootstrap1 0.215       0.5 Clark_Lake     4.79
#>  4 Bootstrap1 0.1         0.5 Clark_Lake     5.27
#>  5 Bootstrap1 0.0464      0.5 Clark_Lake     5.59
#>  6 Bootstrap1 0.0215      0.5 Clark_Lake     5.94
#>  7 Bootstrap1 0.01        0.5 Clark_Lake     6.13
#>  8 Bootstrap1 0.00464     0.5 Clark_Lake     6.22
#>  9 Bootstrap1 0.00215     0.5 Clark_Lake     6.25
#> 10 Bootstrap1 0.001       0.5 Clark_Lake     6.27
#> # … with 90 more rows

Created on 2021-09-08 by the reprex package (v2.0.0)
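For comparison, here is a sketch that swaps `tune_grid()` for a hand-rolled `purrr::map()` over bootstraps, fitting the glmnet path on each analysis set directly and then summarising one coefficient across resamples. It uses `mtcars`, a single mixture of 0.5, and no recipe; all three are assumptions for illustration. Without `step_normalize()` the estimates are on the original predictor scale (glmnet standardizes internally but reports coefficients on the original scale).

```r
library(tidymodels)
library(glmnet)

set.seed(1)
bt <- bootstraps(mtcars, times = 5)
pen_vals <- 10^seq(-3, 0, length.out = 10)

# Hypothetical helper: fit one glmnet path on a bootstrap's analysis set
# and return every coefficient (zeros included) at every penalty
fit_path <- function(split, mixture = 0.5) {
  dat <- analysis(split)
  x <- as.matrix(dat[, setdiff(names(dat), "mpg")])
  path <- glmnet(x, dat$mpg, alpha = mixture, lambda = pen_vals)
  tidy(path, return_zeros = TRUE)
}

coefs <- bt %>%
  mutate(coefs = map(splits, fit_path)) %>%
  select(id, coefs) %>%
  unnest(coefs) %>%
  rename(penalty = lambda)

# Spread of the `wt` coefficient across bootstraps at each penalty.
# Grouping by path step avoids relying on exact floating-point penalties.
wt_summary <- coefs %>%
  filter(term == "wt") %>%
  group_by(step) %>%
  summarise(
    penalty = first(penalty),
    mean = mean(estimate),
    sd = sd(estimate),
    n = n(),
    .groups = "drop"
  )
wt_summary
```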


As another example, here is the same analysis with lm. The code is cleaner:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
tidymodels_prefer()

data(Chicago)

Chicago <- Chicago %>% select(ridership, Clark_Lake, Austin, Harlem)

set.seed(1)
bt <- bootstraps(Chicago, times = 5)
get_lm_coefs <- function(x) {
  x %>% 
    # get the lm model
    extract_fit_engine() %>% 
    # put into a good format
    tidy()
}
tidy_ctrl <- control_grid(extract = get_lm_coefs)

lm_res <- 
  linear_reg() %>% 
  fit_resamples(ridership ~ ., resamples = bt, control = tidy_ctrl)
lm_res
#> # Resampling results
#> # Bootstrap sampling 
#> # A tibble: 5 × 5
#>   splits              id         .metrics         .notes           .extracts    
#>   <list>              <chr>      <list>           <list>           <list>       
#> 1 <split [5698/2107]> Bootstrap1 <tibble [2 × 4]> <tibble [0 × 1]> <tibble [1 ×…
#> 2 <split [5698/2093]> Bootstrap2 <tibble [2 × 4]> <tibble [0 × 1]> <tibble [1 ×…
#> 3 <split [5698/2091]> Bootstrap3 <tibble [2 × 4]> <tibble [0 × 1]> <tibble [1 ×…
#> 4 <split [5698/2088]> Bootstrap4 <tibble [2 × 4]> <tibble [0 × 1]> <tibble [1 ×…
#> 5 <split [5698/2082]> Bootstrap5 <tibble [2 × 4]> <tibble [0 × 1]> <tibble [1 ×…
lm_coefs <- 
  lm_res %>% 
  select(id, .extracts) %>% 
  unnest(.extracts) %>% 
  unnest(.extracts)

lm_coefs %>% 
  select(id, term, estimate, p.value) %>% 
  filter(term == "Clark_Lake")
#> # A tibble: 5 × 4
#>   id         term       estimate   p.value
#>   <chr>      <chr>         <dbl>     <dbl>
#> 1 Bootstrap1 Clark_Lake    0.962 7.58e-225
#> 2 Bootstrap2 Clark_Lake    0.869 1.18e-215
#> 3 Bootstrap3 Clark_Lake    0.876 1.91e-192
#> 4 Bootstrap4 Clark_Lake    0.908 7.61e-200
#> 5 Bootstrap5 Clark_Lake    0.898 3.49e-202

Created on 2021-09-08 by the reprex package (v2.0.0)
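Once the per-resample coefficients are unnested, summarising them is ordinary dplyr. A self-contained sketch, assuming `mtcars` with a two-predictor formula (chosen to avoid rank-deficient bootstrap fits), a workflow instead of a bare model spec, and `control_resamples()` as the control object; the percentile interval is one common choice, not the only one.

```r
library(tidymodels)

set.seed(1)
bt <- bootstraps(mtcars, times = 10)

# Pull the lm object out of each fitted workflow and tidy it
get_lm_coefs <- function(x) {
  x %>%
    extract_fit_engine() %>%
    tidy()
}

lm_wflow <- workflow() %>%
  add_formula(mpg ~ wt + hp) %>%
  add_model(linear_reg())

lm_res <- fit_resamples(
  lm_wflow,
  resamples = bt,
  control = control_resamples(extract = get_lm_coefs)
)

# Outer unnest exposes each resample's extract; inner one exposes tidy() output
lm_coefs <- lm_res %>%
  select(id, .extracts) %>%
  unnest(.extracts) %>%
  unnest(.extracts)

# Bootstrap mean, SD and 2.5% / 97.5% percentiles for each term
lm_summary <- lm_coefs %>%
  group_by(term) %>%
  summarise(
    mean  = mean(estimate),
    sd    = sd(estimate),
    lower = quantile(estimate, 0.025),
    upper = quantile(estimate, 0.975),
    .groups = "drop"
  )
lm_summary
```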


Thank you very much for your reply. I really, really appreciate it.

However, I could not reproduce the same results.

  1. I need to include library(hardhat), but even then it throws the same error as in my original post, no applicable method for 'extract_fit_engine' applied to an object of class "workflow" (also for the lm model). See my reprex below.

  2. I do not get the same columns as you at the end, probably because of this error; see the result of colnames(). Which versions of these libraries are you using?
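A quick diagnostic for the "no applicable method" error is to ask R whether any loaded package registers an `extract_fit_engine()` method for class `workflow`. This sketch uses base R's `getS3method()`; if it reports `FALSE`, outdated packages are the likely culprit, which is where the rest of this thread ends up.

```r
library(workflows)  # attaches extract_fit_engine() (re-exported from hardhat in current versions)

# TRUE if a method for class "workflow" is registered for the generic
has_workflow_method <- !is.null(
  getS3method("extract_fit_engine", "workflow", optional = TRUE)
)
has_workflow_method

# Installed versions of the packages involved
vapply(c("hardhat", "workflows", "parsnip", "tune"),
       function(p) as.character(packageVersion(p)),
       character(1))
```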

library(reprex)

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(hardhat)
#> Warning: package 'hardhat' was built under R version 4.1.1
#> 
#> Attaching package: 'hardhat'
#> The following object is masked from 'package:tune':
#> 
#>     extract_recipe
tidymodels_prefer()

data(Chicago)

Chicago <- Chicago %>% select(ridership, Clark_Lake, Austin, Harlem)

set.seed(1)
bt <- bootstraps(Chicago, times = 5)

rec <- 
  recipe(ridership ~ ., data = Chicago) %>% 
  step_normalize(all_numeric_predictors())

pen_vals <- 10^seq(-3, 0, length.out = 10)
grid <- crossing(penalty = pen_vals, mixture = c(0.5, 1.0))

glmnet_spec <- 
  linear_reg(penalty = tune(), mixture = tune()) %>% 
  set_engine("glmnet", path_values = pen_vals)

glmnet_wflow <- 
  workflow() %>% 
  add_model(glmnet_spec) %>% 
  add_recipe(rec)

get_glmnet_coefs <- function(x) {
  x %>% 
    extract_fit_engine() %>% 
    tidy()
}
parsnip_ctrl <- control_grid(extract = get_glmnet_coefs)

set.seed(2)
glmnet_res <- 
  glmnet_wflow %>% 
  tune_grid(
    resamples = bt,
    grid = grid,
    control = parsnip_ctrl
  )

glmnet_coefs <- 
  glmnet_res %>% 
  select(id, .extracts) %>% 
  unnest(.extracts) %>% 
  # The `penalty` column at this level is redundant, so 
  # we'll remove it. 
  select(id, mixture, .extracts) %>% 
  # Since you get all of the coefficients for each glmnet
  # fit, the values are replicated within a value of mixture.
  # We'll keep the first row so that we don't get the same
  # values over and over again. 
  group_by(id, mixture) %>% 
  slice(1) %>% 
  ungroup() %>% 
  unnest(.extracts) %>% 
  # Rename to be consistent with tidymodels
  rename(penalty = lambda)
#> Error: Can't rename columns that don't exist.
#> x Column `lambda` doesn't exist.

### Error part
glmnet_coefs <- 
  glmnet_res %>% 
  select(id, .extracts) %>% 
  unnest(.extracts) %>% 
  # The `penalty` column at this level is redundant, so 
  # we'll remove it. 
  select(id, mixture, .extracts) %>% 
  # Since you get all of the coefficients for each glmnet
  # fit, the values are replicated within a value of mixture.
  # We'll keep the first row so that we don't get the same
  # values over and over again. 
  group_by(id, mixture) %>% 
  slice(1) %>% 
  ungroup() %>% 
  unnest(.extracts) #%>%
# REMOVED rename(): AFTER unnest() THERE IS NO `lambda` COLUMN (SEE ERROR ABOVE)

glmnet_coefs[['.extracts']][1]
#> [1] "Error in UseMethod(\"extract_fit_engine\") : \n  no applicable method for 'extract_fit_engine' applied to an object of class \"workflow\"\n"
colnames(glmnet_coefs)
#> [1] "id"        "mixture"   ".extracts"
###

Created on 2021-09-09 by the reprex package (v2.0.0)

You probably need to update a few packages. Here is my version info:

─ Session info ────────────────────────────────────────────────────────────────────────────────────────────────
 setting  value                       
 version  R version 4.1.0 (2021-05-18)
 os       macOS Catalina 10.15.7      
 system   x86_64, darwin17.0          
 ui       RStudio                     
 language (EN)                        
 collate  en_US.UTF-8                 
 ctype    en_US.UTF-8                 
 tz       America/New_York            
 date     2021-09-09                  

─ Packages ────────────────────────────────────────────────────────────────────────────────────────────────────
 package      * version    date       lib source                           
 assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.1.0)                   
 backports      1.2.1      2020-12-09 [1] standard (@1.2.1)                
 broom        * 0.7.9      2021-07-27 [1] CRAN (R 4.1.0)                   
 cachem         1.0.6      2021-08-19 [1] CRAN (R 4.1.0)                   
 class          7.3-19     2021-05-03 [1] CRAN (R 4.1.0)                   
 cli            3.0.1      2021-07-17 [1] CRAN (R 4.1.0)                   
 codetools      0.2-18     2020-11-04 [1] CRAN (R 4.1.0)                   
 colorspace     2.0-2      2021-06-24 [1] CRAN (R 4.1.0)                   
 conflicted     1.0.4      2019-06-21 [1] standard (@1.0.4)                
 crayon         1.4.1      2021-02-08 [1] CRAN (R 4.1.0)                   
 DBI            1.1.1      2021-01-15 [1] standard (@1.1.1)                
 dials        * 0.0.9.9000 2021-09-08 [1] Github (tidymodels/dials@c291516)
 DiceDesign     1.9        2021-02-13 [1] standard (@1.9)                  
 digest         0.6.27     2020-10-24 [1] CRAN (R 4.1.0)                   
 dplyr        * 1.0.7      2021-06-18 [1] CRAN (R 4.1.0)                   
 ellipsis       0.3.2      2021-04-29 [1] standard (@0.3.2)                
 fansi          0.5.0      2021-05-25 [1] standard (@0.5.0)                
 fastmap        1.1.0      2021-01-25 [1] standard (@1.1.0)                
 foreach        1.5.1      2020-10-15 [1] CRAN (R 4.1.0)                   
 furrr          0.2.3      2021-06-25 [1] CRAN (R 4.1.0)                   
 future         1.21.0     2020-12-10 [1] CRAN (R 4.1.0)                   
 generics       0.1.0.9000 2021-07-13 [1] Github (r-lib/generics@7cc4465)  
 ggplot2      * 3.3.5      2021-06-25 [1] CRAN (R 4.1.0)                   
 glmnet       * 4.1-2      2021-06-24 [1] CRAN (R 4.1.0)                   
 globals        0.14.0     2020-11-22 [1] CRAN (R 4.1.0)                   
 glue           1.4.2      2020-08-27 [1] CRAN (R 4.1.0)                   
 gower          0.2.2      2020-06-23 [1] CRAN (R 4.1.0)                   
 GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.1.0)                   
 gtable         0.3.0      2019-03-25 [1] CRAN (R 4.1.0)                   
 hardhat        0.1.6      2021-07-14 [1] CRAN (R 4.1.0)                   
 infer        * 1.0.0      2021-08-13 [1] CRAN (R 4.1.0)                   
 ipred          0.9-11     2021-03-12 [1] standard (@0.9-11)               
 iterators      1.0.13     2020-10-15 [1] CRAN (R 4.1.0)                   
 lattice        0.20-44    2021-05-02 [1] CRAN (R 4.1.0)                   
 lava           1.6.9      2021-03-11 [1] standard (@1.6.9)                
 lhs            1.1.1      2020-10-05 [1] CRAN (R 4.1.0)                   
 lifecycle      1.0.0      2021-02-15 [1] CRAN (R 4.1.0)                   
 listenv        0.8.0      2019-12-05 [1] CRAN (R 4.1.0)                   
 lubridate      1.7.10     2021-02-26 [1] CRAN (R 4.1.0)                   
 magrittr       2.0.1      2020-11-17 [1] CRAN (R 4.1.0)                   
 MASS           7.3-54     2021-05-03 [1] CRAN (R 4.1.0)                   
 Matrix       * 1.3-4      2021-06-01 [1] standard (@1.3-4)                
 modeldata    * 0.1.1      2021-07-14 [1] CRAN (R 4.1.0)                   
 munsell        0.5.0      2018-06-12 [1] CRAN (R 4.1.0)                   
 nnet           7.3-16     2021-05-03 [1] CRAN (R 4.1.0)                   
 parallelly     1.27.0     2021-07-19 [1] CRAN (R 4.1.0)                   
 parsnip      * 0.1.7      2021-07-21 [1] CRAN (R 4.1.0)                   
 pillar         1.6.2      2021-07-29 [1] CRAN (R 4.1.0)                   
 pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.1.0)                   
 plyr           1.8.6      2020-03-03 [1] CRAN (R 4.1.0)                   
 pROC           1.17.0.1   2021-01-13 [1] standard (@1.17.0.)              
 prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)                   
 purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.1.0)                   
 R6             2.5.1      2021-08-19 [1] CRAN (R 4.1.0)                   
 Rcpp           1.0.7      2021-07-07 [1] CRAN (R 4.1.0)                   
 recipes      * 0.1.16     2021-04-16 [1] CRAN (R 4.1.0)                   
 rlang        * 0.4.11     2021-04-30 [1] CRAN (R 4.1.0)                   
 rpart          4.1-15     2019-04-12 [1] CRAN (R 4.1.0)                   
 rsample      * 0.1.0      2021-05-08 [1] CRAN (R 4.1.0)                   
 rstudioapi     0.13       2020-11-12 [1] CRAN (R 4.1.0)                   
 scales       * 1.1.1      2020-05-11 [1] CRAN (R 4.1.0)                   
 sessioninfo    1.1.1      2018-11-05 [1] CRAN (R 4.1.0)                   
 shape          1.4.6      2021-05-19 [1] standard (@1.4.6)                
 survival       3.2-11     2021-04-26 [1] CRAN (R 4.1.0)                   
 tibble       * 3.1.3      2021-07-23 [1] CRAN (R 4.1.0)                   
 tidymodels   * 0.1.3      2021-04-19 [1] CRAN (R 4.1.0)                   
 tidyr        * 1.1.3      2021-03-03 [1] standard (@1.1.3)                
 tidyselect     1.1.1      2021-04-30 [1] standard (@1.1.1)                
 timeDate       3043.102   2018-02-21 [1] CRAN (R 4.1.0)                   
 tune         * 0.1.6      2021-07-21 [1] CRAN (R 4.1.0)                   
 utf8           1.2.2      2021-07-24 [1] CRAN (R 4.1.0)                   
 vctrs        * 0.3.8      2021-04-29 [1] standard (@0.3.8)                
 withr          2.4.2      2021-04-18 [1] standard (@2.4.2)                
 workflows    * 0.2.3      2021-07-16 [1] CRAN (R 4.1.0)                   
 workflowsets * 0.1.0      2021-07-22 [1] CRAN (R 4.1.0)                   
 yaml           2.2.1      2020-02-01 [1] CRAN (R 4.1.0)                   
 yardstick    * 0.0.8      2021-03-28 [1] CRAN (R 4.1.0)                   

[1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
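A small sketch for comparing a local installation with the session info above. The reference versions are copied from that listing; they are what worked for the answerer, not documented minimum requirements.

```r
# Reference versions taken from the session info above
reference <- c(hardhat = "0.1.6", parsnip = "0.1.7", tune = "0.1.6",
               workflows = "0.2.3", glmnet = "4.1-2")

# Installed version of each package, or NA if it is not installed
installed <- vapply(names(reference), function(pkg) {
  if (requireNamespace(pkg, quietly = TRUE)) {
    as.character(utils::packageVersion(pkg))
  } else {
    NA_character_
  }
}, character(1))

version_check <- data.frame(
  package   = names(reference),
  installed = unname(installed),
  reference = unname(reference),
  # package_version() handles both "0.1.6" and "4.1-2" style strings
  meets_reference = unname(!is.na(installed) &
    package_version(ifelse(is.na(installed), "0", installed)) >=
      package_version(reference)),
  row.names = NULL
)
version_check

# Packages flagged FALSE could then be reinstalled, e.g.:
# install.packages(version_check$package[!version_check$meets_reference])
```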

Thank you, I found the packages that wouldn't automatically update and manually updated them to match your versions. Everything seems to work now.

Thank you very much
