Permutation-based variable importance with a glmnet model fit

The {vip} package provides variable importance through model-agnostic methods such as permutation. For a glmnet model, the permutation method fails because predict.glmnet() expects its new data in the newx argument, and the default predict call does not supply it.

My second question is which metric should be selected for Poisson regression.
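
On the metric question, one candidate that is sometimes used for count models is the mean Poisson deviance, computed on the response (count) scale. The sketch below is only an illustration in base R under assumed, simulated data: `poisson_deviance()`, `pred_fun()` and `permute_importance()` are hypothetical helpers, not part of {vip}, and the permutation loop is written by hand rather than taken from the package. Note that predict() on a glm returns the linear predictor by default, so the wrapper asks for type = "response".

```r
# Mean Poisson deviance: 2 * mean(y * log(y / mu) - (y - mu)),
# where the y * log(y / mu) term is taken as 0 when y == 0.
poisson_deviance <- function(actual, predicted) {
  term <- ifelse(actual == 0, 0, actual * log(actual / predicted))
  2 * mean(term - (actual - predicted))
}

# Simulated data (an assumption for illustration): x1 drives the counts, x2 is noise.
set.seed(42)
n <- 500
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
dat$count <- rpois(n, lambda = exp(0.5 + 0.7 * dat$x1))

fit <- glm(count ~ x1 + x2, family = poisson(), data = dat)

# Prediction wrapper that returns expected counts, not the log-scale linear predictor.
pred_fun <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata, type = "response"))
}

baseline <- poisson_deviance(dat$count, pred_fun(fit, dat))

# Hand-rolled permutation importance: average increase in deviance
# after shuffling one column (larger means more important).
permute_importance <- function(var, nsim = 10) {
  scores <- replicate(nsim, {
    shuffled <- dat
    shuffled[[var]] <- sample(shuffled[[var]])
    poisson_deviance(dat$count, pred_fun(fit, shuffled))
  })
  mean(scores) - baseline
}

imp <- sapply(c("x1", "x2"), permute_importance)
imp
```

A function like `poisson_deviance()` could presumably be passed to vi() as a custom metric together with smaller_is_better = TRUE, but the argument names {vip} expects for custom metric functions may differ between versions, so that is worth checking against the installed documentation.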

library(poissonreg)
#> Loading required package: parsnip
library(tidyverse)
#> Warning: package 'tibble' was built under R version 4.0.3
library(tidymodels)
#> -- Attaching packages --------------------------------- tidymodels 0.1.1.9000 --
#> √ broom     0.7.2          √ rsample   0.0.8.9000
#> √ dials     0.0.9.9000     √ tune      0.1.1.9000
#> √ infer     0.5.3          √ workflows 0.2.1.9000
#> √ modeldata 0.1.0          √ yardstick 0.0.7     
#> √ recipes   0.1.14
#> Warning: package 'broom' was built under R version 4.0.3
#> Warning: package 'modeldata' was built under R version 4.0.3
#> Warning: package 'recipes' was built under R version 4.0.3
#> -- Conflicts ----------------------------------------- tidymodels_conflicts() --
#> x scales::discard() masks purrr::discard()
#> x dplyr::filter()   masks stats::filter()
#> x recipes::fixed()  masks stringr::fixed()
#> x dplyr::lag()      masks stats::lag()
#> x yardstick::spec() masks readr::spec()
#> x recipes::step()   masks stats::step()
library(vip)
#> 
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#> 
#>     vi

## vi with permute method for model with glmnet engine needs additional argument "newx" for predict.glmnet
glmnet_fit <- poisson_reg() %>%
  set_engine("glmnet") %>%
  fit(count ~ (.) ^ 2, data = seniors)

glmnet_fit %>%
  vi(
    method = "permute",
    train = seniors,
    target = "count",
    metric = "rsquared",
    pred_wrapper = predict,
    nsim = 10
  )
#> Error in predict.glmnet(object, newdata = train_x): You need to supply a value for 'newx'
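
The error above comes from the prediction wrapper being called as pred_wrapper(object, newdata = ...), while predict.glmnet() wants a matrix passed as newx. A possible workaround is a small adapter that rebuilds the design matrix and forwards it under the right name. The sketch below uses glmnet directly on simulated data rather than the parsnip fit from this post; `build_x()` and `glmnet_pred()` are hypothetical helpers, the penalty value s = 0.01 is an arbitrary assumption, and the design matrix is assumed to follow the same (.)^2 formula as the model.

```r
library(glmnet)

# Simulated count data (an assumption for illustration).
set.seed(101)
n <- 200
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$count <- rpois(n, lambda = exp(0.3 + 0.8 * dat$x1 - 0.5 * dat$x2))

# Build the predictor matrix with main effects and two-way interactions,
# dropping the target (if present) and the intercept column.
build_x <- function(d) {
  d <- d[, setdiff(names(d), "count"), drop = FALSE]
  model.matrix(~ (.)^2, data = d)[, -1, drop = FALSE]
}

fit <- glmnet(build_x(dat), dat$count, family = "poisson")

# Adapter: accepts `newdata` (the name used in the failing call above)
# and forwards it to predict.glmnet() as `newx`.
glmnet_pred <- function(object, newdata) {
  as.numeric(predict(object, newx = build_x(newdata), s = 0.01, type = "response"))
}

# Called the same way as in the error message: pred_wrapper(object, newdata = train_x)
preds <- glmnet_pred(fit, newdata = dat[, c("x1", "x2", "x3")])
head(preds)
```

With a wrapper of this shape, pred_wrapper = glmnet_pred would replace pred_wrapper = predict in the vi() call. For the parsnip fit in this post, the columns produced by the wrapper would have to match the matrix parsnip built from the formula, which I have not verified.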

## vi with model method works
glmnet_fit %>%
  vi(method = "model")
#> # A tibble: 6 x 3
#>   Variable                  Importance Sign 
#>   <chr>                          <dbl> <chr>
#> 1 marijuanayes:cigaretteyes      2.78  POS  
#> 2 marijuanayes:alcoholyes        2.42  POS  
#> 3 cigaretteyes:alcoholyes        2.03  POS  
#> 4 alcoholyes                     0.492 POS  
#> 5 cigaretteyes                  -1.85  NEG  
#> 6 marijuanayes                  -4.68  NEG

## glm engine works
glm_fit <- poisson_reg() %>%
  set_engine("glm") %>%
  fit(count ~ (.) ^ 2, data = seniors)

glm_fit %>%
  vi(
    method = "permute",
    train = seniors,
    target = "count",
    metric = "rsquared",
    pred_wrapper = predict,
    nsim = 10
  )
#> # A tibble: 3 x 3
#>   Variable  Importance StDev
#>   <chr>          <dbl> <dbl>
#> 1 alcohol        0.453 0.214
#> 2 cigarette      0.249 0.253
#> 3 marijuana      0.247 0.163

Created on 2020-11-03 by the reprex package (v0.3.0)
