Using ALEPlot for XGBoost

Trying to explore ALEPlots from the ALEPlot package for xgboost models, struggling to get the plots out... any help?

Reprex adapated from Julia SIlge's blog below has simplified tuning, limited trees heavily.

library(tidyverse)

ratings <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-01-25/ratings.csv")
#> Rows: 21831 Columns: 10
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> chr (3): name, url, thumbnail
#> dbl (7): num, id, year, rank, average, bayes_average, users_rated
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
details <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2022/2022-01-25/details.csv")
#> Rows: 21631 Columns: 23
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> chr (10): primary, description, boardgamecategory, boardgamemechanic, boardg...
#> dbl (13): num, id, yearpublished, minplayers, maxplayers, playingtime, minpl...
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

ratings_joined <-
  ratings %>%
  left_join(details, by = "id")

ggplot(ratings_joined, aes(average)) +
  geom_histogram(alpha = 0.8)
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
library(tidymodels)

set.seed(123)
game_split <-
  ratings_joined %>%
  select(name, average, matches("min|max"), boardgamecategory) %>%
  na.omit() %>%
  initial_split(strata = average)

game_train <- training(game_split)
game_test <- testing(game_split)

set.seed(234)
game_folds <- vfold_cv(game_train, strata = average)
game_folds
#> #  10-fold cross-validation using stratification 
#> # A tibble: 10 × 2
#>    splits               id    
#>    <list>               <chr> 
#>  1 <split [14407/1602]> Fold01
#>  2 <split [14407/1602]> Fold02
#>  3 <split [14407/1602]> Fold03
#>  4 <split [14408/1601]> Fold04
#>  5 <split [14408/1601]> Fold05
#>  6 <split [14408/1601]> Fold06
#>  7 <split [14408/1601]> Fold07
#>  8 <split [14408/1601]> Fold08
#>  9 <split [14410/1599]> Fold09
#> 10 <split [14410/1599]> Fold10

library(textrecipes)

split_category <- function(x) {
  x %>%
    stringr::str_split(", ") %>%
    purrr::map(stringr::str_remove_all, "[:punct:]") %>%
    purrr::map(stringr::str_squish) %>%
    purrr::map(stringr::str_to_lower) %>%
    purrr::map(stringr::str_replace_all, " ", "_")
}

game_rec <-
  recipe(average ~ ., data = game_train) %>%
  update_role(name, new_role = "id") %>%
  step_tokenize(boardgamecategory, custom_token = split_category) %>%
  step_tokenfilter(boardgamecategory, max_tokens = 30) %>%
  step_tf(boardgamecategory)

## just to make sure this works as expected
game_prep <- prep(game_rec)
bake(game_prep, new_data = NULL) %>% str()
#> tibble [16,009 × 37] (S3: tbl_df/tbl/data.frame)
#>  $ name                                      : Factor w/ 15781 levels "'65: Squad-Level Combat in the Jungles of Vietnam",..: 10857 8587 14642 858 15728 6819 13313 1490 3143 9933 ...
#>  $ minplayers                                : num [1:16009] 2 2 2 4 2 1 2 2 4 2 ...
#>  $ maxplayers                                : num [1:16009] 6 8 10 10 6 8 6 2 16 6 ...
#>  $ minplaytime                               : num [1:16009] 120 60 30 30 60 20 60 30 60 45 ...
#>  $ maxplaytime                               : num [1:16009] 120 180 30 30 90 20 60 30 60 45 ...
#>  $ minage                                    : num [1:16009] 10 8 6 12 15 6 8 8 13 8 ...
#>  $ average                                   : num [1:16009] 5.59 4.37 5.41 5.79 5.8 5.62 4.31 4.66 5.68 5.14 ...
#>  $ tf_boardgamecategory_abstract_strategy    : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_action_dexterity     : int [1:16009] 0 0 0 0 0 1 0 0 0 0 ...
#>  $ tf_boardgamecategory_adventure            : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_ancient              : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_animals              : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_bluffing             : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_card_game            : int [1:16009] 0 0 1 1 0 0 0 0 0 1 ...
#>  $ tf_boardgamecategory_childrens_game       : int [1:16009] 0 0 0 0 0 0 1 1 0 0 ...
#>  $ tf_boardgamecategory_deduction            : int [1:16009] 0 0 0 0 0 0 0 1 0 0 ...
#>  $ tf_boardgamecategory_dice                 : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_economic             : int [1:16009] 0 1 0 0 0 0 1 0 0 0 ...
#>  $ tf_boardgamecategory_exploration          : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_fantasy              : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_fighting             : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_horror               : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_humor                : int [1:16009] 0 0 0 1 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_medieval             : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_miniatures           : int [1:16009] 0 0 0 0 1 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_movies_tv_radio_theme: int [1:16009] 0 0 1 0 1 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_nautical             : int [1:16009] 0 0 0 0 0 0 0 1 0 0 ...
#>  $ tf_boardgamecategory_negotiation          : int [1:16009] 0 1 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_party_game           : int [1:16009] 0 0 0 1 0 1 0 0 1 0 ...
#>  $ tf_boardgamecategory_print_play           : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_puzzle               : int [1:16009] 0 0 0 0 0 0 0 0 1 0 ...
#>  $ tf_boardgamecategory_racing               : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_realtime             : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_science_fiction      : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...
#>  $ tf_boardgamecategory_trivia               : int [1:16009] 0 0 0 0 0 0 0 0 1 0 ...
#>  $ tf_boardgamecategory_wargame              : int [1:16009] 1 0 0 0 0 0 0 1 0 0 ...
#>  $ tf_boardgamecategory_world_war_ii         : int [1:16009] 0 0 0 0 0 0 0 0 0 0 ...


xgb_spec <-
  boost_tree(
    trees = 100,
    mtry = tune(),
    min_n = 100,
    learn_rate = 0.01
  ) %>%
  set_engine("xgboost") %>%
  set_mode("regression")

xgb_wf <- workflow(game_rec, xgb_spec)
xgb_wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 3 Recipe Steps
#> 
#> • step_tokenize()
#> • step_tokenfilter()
#> • step_tf()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = tune()
#>   trees = 100
#>   min_n = 100
#>   learn_rate = 0.01
#> 
#> Computational engine: xgboost

library(finetune)
doParallel::registerDoParallel()

set.seed(234)
xgb_game_rs <-
  tune_race_anova(
    xgb_wf,
    game_folds,
    grid = 5,
    control = control_race(verbose_elim = TRUE)
  )
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> ℹ Racing will minimize the rmse metric.ℹ Resamples are analyzed in a random order.ℹ Fold10: 3 eliminated; 2 candidates remain.
#> ℹ Fold06: 0 eliminated; 2 candidates remain.
#> ℹ Fold08: 0 eliminated; 2 candidates remain.
#> ℹ Fold01: 0 eliminated; 2 candidates remain.
#> ℹ Fold04: 0 eliminated; 2 candidates remain.
#> ℹ Fold02: 0 eliminated; 2 candidates remain.
#> ℹ Fold09: 0 eliminated; 2 candidates remain.

xgb_game_rs
#> # Tuning results
#> # 10-fold cross-validation using stratification 
#> # A tibble: 10 × 5
#>    splits               id     .order .metrics          .notes          
#>    <list>               <chr>   <int> <list>            <list>          
#>  1 <split [14407/1602]> Fold03      1 <tibble [10 × 5]> <tibble [0 × 3]>
#>  2 <split [14408/1601]> Fold05      2 <tibble [10 × 5]> <tibble [0 × 3]>
#>  3 <split [14410/1599]> Fold10      3 <tibble [10 × 5]> <tibble [0 × 3]>
#>  4 <split [14408/1601]> Fold06      4 <tibble [4 × 5]>  <tibble [0 × 3]>
#>  5 <split [14408/1601]> Fold08      5 <tibble [4 × 5]>  <tibble [0 × 3]>
#>  6 <split [14407/1602]> Fold01      6 <tibble [4 × 5]>  <tibble [0 × 3]>
#>  7 <split [14408/1601]> Fold04      7 <tibble [4 × 5]>  <tibble [0 × 3]>
#>  8 <split [14407/1602]> Fold02      8 <tibble [4 × 5]>  <tibble [0 × 3]>
#>  9 <split [14410/1599]> Fold09      9 <tibble [4 × 5]>  <tibble [0 × 3]>
#> 10 <split [14408/1601]> Fold07     10 <tibble [4 × 5]>  <tibble [0 × 3]>

show_best(xgb_game_rs)
#> Warning: No value of `metric` was given; metric 'rmse' will be used.
#> # A tibble: 2 × 7
#>    mtry .metric .estimator  mean     n std_err .config             
#>   <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1    27 rmse    standard    2.32    10 0.00405 Preprocessor1_Model5
#> 2    29 rmse    standard    2.32    10 0.00411 Preprocessor1_Model1

xgb_last <-
  xgb_wf %>%
  finalize_workflow(select_best(xgb_game_rs, "rmse")) %>%
  last_fit(game_split)

xgb_last
#> # Resampling results
#> # Manual resampling 
#> # A tibble: 1 × 6
#>   splits               id              .metrics .notes   .predictions .workflow 
#>   <list>               <chr>           <list>   <list>   <list>       <list>    
#> 1 <split [16009/5339]> train/test spl… <tibble> <tibble> <tibble>     <workflow>

library(vip)
#> 
#> Attaching package: 'vip'
#> 
#> The following object is masked from 'package:utils':
#> 
#>     vi

xgb_fit <- extract_fit_parsnip(xgb_last)
vip(xgb_fit, geom = "point", num_features = 12)

library(SHAPforxgboost)

game_shap <-
  shap.prep(
    xgb_model = extract_fit_engine(xgb_fit),
    X_train = bake(game_prep,
                   has_role("predictor"),
                   new_data = NULL,
                   composition = "matrix"
    )
  )

#shap.plot.summary(game_shap)

shap.plot.dependence(
  game_shap,
  x = "minage",
  color_feature = "minplayers",
  size0 = 1.2,
  smooth = FALSE, add_hist = TRUE
)

library(ALEPlot)

xgb_engine <- extract_fit_engine(xgb_fit)

yhat <- function(xgb_engine, newdata) as.numeric(predict(xgb_engine, newdata))

game_train_df <- game_train %>% 
  select(-average) %>% 
  as.data.frame()

ALEPlot(
  X = game_train_df,
  X.model = xgb_engine,
  pred.fun = yhat,
  J = 1,
  K = 50,
  NA.plot = TRUE
)
#> [1] "error:  class(X[,J]) must be either factor or numeric or integer"
#> Error in ALEPlot(X = game_train_df, X.model = xgb_engine, pred.fun = yhat, : object 'x' not found

This topic was automatically closed 21 days after the last reply. New replies are no longer allowed.

If you have a query related to it or one of the replies, start a new topic and refer back with a link.