tidymodels: tuning scale_pos_weight in xgboost

Hi, all. I'm confused about how to tune the scale_pos_weight hyperparameter of xgboost models in tidymodels. I've read the documentation but am clearly not implementing it correctly. Would love help with this one!

xgb_spec <- boost_tree(
    trees = tune(),
    min_n = tune(),
    mtry = tune(),
    learn_rate = tune(),
    scale_pos_weight = tune()
  ) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

Error in boost_tree(trees = tune(), min_n = tune(), mtry = tune(), learn_rate = tune(), :
unused argument (scale_pos_weight = tune())

I have also tried:

  1. Using scale_pos_weight(range = c(10, 200))
  2. Putting it in the set_engine("xgboost", scale_pos_weight = tune())

I know that I can pass a fixed scale_pos_weight value to xgboost via the set_engine() statement, but I'm stumped as to how to tune it, even though the closed issues on GitHub show that it is clearly possible.

Would appreciate any help!

Thank you all so much!
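The `unused argument` error is consistent with `scale_pos_weight` not being one of `boost_tree()`'s main arguments; it is an xgboost engine argument, so it has to go through `set_engine()`. A minimal sketch, assuming a reasonably recent tidymodels installation, that checks this and rebuilds the spec from the question:

```r
library(tidymodels)

# boost_tree() only exposes engine-agnostic "main" arguments;
# scale_pos_weight is not among them, hence the "unused argument" error.
"scale_pos_weight" %in% names(formals(boost_tree))  # FALSE

# Engine-specific arguments such as scale_pos_weight are passed to
# set_engine(), and tune() can be used there as well.
xgb_spec <- boost_tree(
  trees = tune(),
  min_n = tune(),
  mtry = tune(),
  learn_rate = tune()
) %>%
  set_engine("xgboost", scale_pos_weight = tune()) %>%
  set_mode("classification")

# The tunable engine argument is stored with the model specification
names(xgb_spec$eng_args)
```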

Hi,
You should check that your tidymodels packages are up to date (you can use tidymodels::tidymodels_update()).

I ran the code from the closed GitHub issue and it works fine.

HTH

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

tidymodels_prefer()
options(tidymodels.dark = TRUE)

library(mlbench)

data("PimaIndiansDiabetes")

set.seed(24)

df <- PimaIndiansDiabetes %>%
    mutate(diabetes = fct_relevel(diabetes, 'pos'))

xgb_model_1 <-
    boost_tree(trees = 150,
               tree_depth = tune()) %>%
    set_engine('xgboost', scale_pos_weight = tune(), eval_metric = 'auc') %>%
    set_mode('classification')

xgb_grid <- grid_regular(scale_pos_weight(range = c(-3, -1), trans = log10_trans()),
                         tree_depth(range = c(2, 4)))
xgb_grid
#> # A tibble: 9 x 2
#>   scale_pos_weight tree_depth
#>              <dbl>      <int>
#> 1            0.001          2
#> 2            0.01           2
#> 3            0.1            2
#> 4            0.001          3
#> 5            0.01           3
#> 6            0.1            3
#> 7            0.001          4
#> 8            0.01           4
#> 9            0.1            4

set.seed(1)

library(doParallel)
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: iterators
#> Loading required package: parallel
cores <- detectCores(logical = FALSE)
cl <- makePSOCKcluster(cores)
registerDoParallel(cl)

xgb_model_1_res <-
    tune_grid(xgb_model_1, diabetes ~.,
              resamples = vfold_cv(df),
              grid = xgb_grid)
              # grid = tibble(scale_pos_weight = 10^c(-3:-1)))

collect_metrics(xgb_model_1_res)
#> # A tibble: 18 x 8
#>    tree_depth scale_pos_weight .metric  .estimator  mean     n std_err .config  
#>         <int>            <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>    
#>  1          2            0.001 accuracy binary     0.651    10  0.0137 Preproce~
#>  2          2            0.001 roc_auc  binary     0.5      10  0      Preproce~
#>  3          2            0.01  accuracy binary     0.651    10  0.0137 Preproce~
#>  4          2            0.01  roc_auc  binary     0.806    10  0.0158 Preproce~
#>  5          2            0.1   accuracy binary     0.714    10  0.0189 Preproce~
#>  6          2            0.1   roc_auc  binary     0.802    10  0.0153 Preproce~
#>  7          3            0.001 accuracy binary     0.651    10  0.0137 Preproce~
#>  8          3            0.001 roc_auc  binary     0.5      10  0      Preproce~
#>  9          3            0.01  accuracy binary     0.651    10  0.0137 Preproce~
#> 10          3            0.01  roc_auc  binary     0.806    10  0.0158 Preproce~
#> 11          3            0.1   accuracy binary     0.720    10  0.0208 Preproce~
#> 12          3            0.1   roc_auc  binary     0.796    10  0.0175 Preproce~
#> 13          4            0.001 accuracy binary     0.651    10  0.0137 Preproce~
#> 14          4            0.001 roc_auc  binary     0.5      10  0      Preproce~
#> 15          4            0.01  accuracy binary     0.651    10  0.0137 Preproce~
#> 16          4            0.01  roc_auc  binary     0.806    10  0.0158 Preproce~
#> 17          4            0.1   accuracy binary     0.719    10  0.0192 Preproce~
#> 18          4            0.1   roc_auc  binary     0.793    10  0.0204 Preproce~

show_best(xgb_model_1_res, metric = "roc_auc")
#> # A tibble: 5 x 8
#>   tree_depth scale_pos_weight .metric .estimator  mean     n std_err .config    
#>        <int>            <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>      
#> 1          2             0.01 roc_auc binary     0.806    10  0.0158 Preprocess~
#> 2          3             0.01 roc_auc binary     0.806    10  0.0158 Preprocess~
#> 3          4             0.01 roc_auc binary     0.806    10  0.0158 Preprocess~
#> 4          2             0.1  roc_auc binary     0.802    10  0.0153 Preprocess~
#> 5          3             0.1  roc_auc binary     0.796    10  0.0175 Preprocess~

Created on 2022-01-03 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> - Session info ---------------------------------------------------------------
#>  setting  value
#>  version  R version 4.1.2 (2021-11-01)
#>  os       Windows 10 x64 (build 19044)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_United Kingdom.1252
#>  ctype    English_United Kingdom.1252
#>  tz       Europe/Berlin
#>  date     2022-01-03
#>  pandoc   2.16.2 @ C:/Users/pinog/AppData/Local/R-MINI~1/Library/bin/ (via rmarkdown)
#> 
#> - Packages -------------------------------------------------------------------
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.1.0)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.1.2)
#>  broom        * 0.7.10     2021-10-31 [1] CRAN (R 4.1.1)
#>  cachem         1.0.6      2021-08-19 [1] CRAN (R 4.1.1)
#>  cellranger     1.1.0      2016-07-27 [1] CRAN (R 4.1.0)
#>  class          7.3-19     2021-05-03 [2] CRAN (R 4.1.2)
#>  cli            3.1.0      2021-10-27 [1] CRAN (R 4.1.1)
#>  codetools      0.2-18     2020-11-04 [2] CRAN (R 4.1.2)
#>  colorspace     2.0-2      2021-06-24 [1] CRAN (R 4.1.0)
#>  conflicted     1.1.0      2021-11-26 [1] CRAN (R 4.1.1)
#>  crayon         1.4.2      2021-10-29 [1] CRAN (R 4.1.1)
#>  data.table     1.14.2     2021-09-27 [1] CRAN (R 4.1.1)
#>  DBI            1.1.2      2021-12-20 [1] CRAN (R 4.1.2)
#>  dbplyr         2.1.1      2021-04-06 [1] CRAN (R 4.1.0)
#>  dials        * 0.0.10     2021-09-10 [1] CRAN (R 4.1.1)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.1.0)
#>  digest         0.6.29     2021-12-01 [1] CRAN (R 4.1.2)
#>  doParallel   * 1.0.16     2020-10-16 [1] CRAN (R 4.1.0)
#>  dplyr        * 1.0.7      2021-06-18 [1] CRAN (R 4.1.0)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate       0.14       2019-05-28 [1] CRAN (R 4.1.0)
#>  fansi          0.5.0      2021-05-25 [1] CRAN (R 4.1.0)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.1.0)
#>  forcats      * 0.5.1      2021-01-27 [1] CRAN (R 4.1.0)
#>  foreach      * 1.5.1      2020-10-15 [1] CRAN (R 4.1.0)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.1.2)
#>  furrr          0.2.3      2021-06-25 [1] CRAN (R 4.1.0)
#>  future         1.23.0     2021-10-31 [1] CRAN (R 4.1.1)
#>  future.apply   1.8.1      2021-08-10 [1] CRAN (R 4.1.0)
#>  generics       0.1.1      2021-10-25 [1] CRAN (R 4.1.1)
#>  ggplot2      * 3.3.5      2021-06-25 [1] CRAN (R 4.1.0)
#>  globals        0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue           1.6.0      2021-12-17 [1] CRAN (R 4.1.2)
#>  gower          0.2.2      2020-06-23 [1] CRAN (R 4.1.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.1.0)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.1.0)
#>  hardhat        0.1.6      2021-07-14 [1] CRAN (R 4.1.0)
#>  haven          2.4.3      2021-08-04 [1] CRAN (R 4.1.0)
#>  highr          0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  hms            1.1.1      2021-09-26 [1] CRAN (R 4.1.1)
#>  htmltools      0.5.2      2021-08-25 [1] CRAN (R 4.1.1)
#>  httr           1.4.2      2020-07-20 [1] CRAN (R 4.1.0)
#>  infer        * 1.0.0      2021-08-13 [1] CRAN (R 4.1.0)
#>  ipred          0.9-12     2021-09-15 [1] CRAN (R 4.1.1)
#>  iterators    * 1.0.13     2020-10-15 [1] CRAN (R 4.1.0)
#>  jsonlite       1.7.2      2020-12-09 [1] CRAN (R 4.1.0)
#>  knitr          1.37       2021-12-16 [1] CRAN (R 4.1.2)
#>  lattice        0.20-45    2021-09-22 [2] CRAN (R 4.1.2)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.1.1)
#>  lhs            1.1.3      2021-09-08 [1] CRAN (R 4.1.1)
#>  lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.1.1)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.1.1)
#>  magrittr       2.0.1      2020-11-17 [1] CRAN (R 4.1.0)
#>  MASS           7.3-54     2021-05-03 [1] CRAN (R 4.1.0)
#>  Matrix         1.3-4      2021-06-01 [2] CRAN (R 4.1.2)
#>  memoise        2.0.1      2021-11-26 [1] CRAN (R 4.1.1)
#>  mlbench      * 2.1-3      2021-01-29 [1] CRAN (R 4.1.0)
#>  modeldata    * 0.1.1      2021-07-14 [1] CRAN (R 4.1.0)
#>  modelr         0.1.8      2020-05-19 [1] CRAN (R 4.1.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.1.0)
#>  nnet           7.3-16     2021-05-03 [2] CRAN (R 4.1.2)
#>  parallelly     1.30.0     2021-12-17 [1] CRAN (R 4.1.2)
#>  parsnip      * 0.1.7      2021-07-21 [1] CRAN (R 4.1.0)
#>  pillar         1.6.4      2021-10-18 [1] CRAN (R 4.1.1)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  plyr           1.8.6      2020-03-03 [1] CRAN (R 4.1.0)
#>  pROC           1.18.0     2021-09-03 [1] CRAN (R 4.1.1)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.1.2)
#>  R.cache        0.15.0     2021-04-30 [1] CRAN (R 4.1.0)
#>  R.methodsS3    1.8.1      2020-08-26 [1] CRAN (R 4.1.0)
#>  R.oo           1.24.0     2020-08-26 [1] CRAN (R 4.1.0)
#>  R.utils        2.11.0     2021-09-26 [1] CRAN (R 4.1.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.1.1)
#>  Rcpp           1.0.7      2021-07-07 [1] CRAN (R 4.1.0)
#>  readr        * 2.1.1      2021-11-30 [1] CRAN (R 4.1.2)
#>  readxl         1.3.1      2019-03-13 [1] CRAN (R 4.1.0)
#>  recipes      * 0.1.17     2021-09-27 [1] CRAN (R 4.1.1)
#>  reprex         2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  rlang          0.4.12     2021-10-18 [1] CRAN (R 4.1.1)
#>  rmarkdown      2.11       2021-09-14 [1] CRAN (R 4.1.1)
#>  rpart          4.1-15     2019-04-12 [2] CRAN (R 4.1.2)
#>  rsample      * 0.1.1      2021-11-08 [1] CRAN (R 4.1.1)
#>  rstudioapi     0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  rvest          1.0.2      2021-10-16 [1] CRAN (R 4.1.1)
#>  scales       * 1.1.1      2020-05-11 [1] CRAN (R 4.1.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.1.2)
#>  stringi        1.7.6      2021-11-29 [1] CRAN (R 4.1.1)
#>  stringr      * 1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
#>  styler         1.6.2      2021-09-23 [1] CRAN (R 4.1.1)
#>  survival       3.2-13     2021-08-24 [1] CRAN (R 4.1.1)
#>  tibble       * 3.1.6      2021-11-07 [1] CRAN (R 4.1.1)
#>  tidymodels   * 0.1.4      2021-10-01 [1] CRAN (R 4.1.1)
#>  tidyr        * 1.1.4      2021-09-27 [1] CRAN (R 4.1.1)
#>  tidyselect     1.1.1      2021-04-30 [1] CRAN (R 4.1.0)
#>  tidyverse    * 1.3.1      2021-04-15 [1] CRAN (R 4.1.0)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.1.0)
#>  tune         * 0.1.6      2021-07-21 [1] CRAN (R 4.1.2)
#>  tzdb           0.2.0      2021-10-27 [1] CRAN (R 4.1.1)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs          0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  withr          2.4.3      2021-11-30 [1] CRAN (R 4.1.2)
#>  workflows    * 0.2.4      2021-10-12 [1] CRAN (R 4.1.1)
#>  workflowsets * 0.1.0      2021-07-22 [1] CRAN (R 4.1.0)
#>  xfun           0.29       2021-12-14 [1] CRAN (R 4.1.2)
#>  xgboost        1.5.0.2    2021-11-21 [1] CRAN (R 4.1.1)
#>  xml2           1.3.3      2021-11-30 [1] CRAN (R 4.1.2)
#>  yaml           2.2.1      2020-02-01 [1] CRAN (R 4.1.0)
#>  yardstick    * 0.0.9      2021-11-22 [1] CRAN (R 4.1.1)
#> 
#>  [1] D:/R/library
#>  [2] C:/Program Files/R/R-4.1.2/library
#> 
#> ------------------------------------------------------------------------------
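Two details of the grid and results above can be checked with base R alone. With `trans = log10_trans()`, the `range` passed to `scale_pos_weight()` is on the log10 scale, so `c(-3, -1)` yields 0.001, 0.01 and 0.1; a range of 10 to 200 on the same scale would be `c(1, log10(200))`. And the accuracy of 0.651 with ROC AUC 0.5 at `scale_pos_weight = 0.001` matches a model that predicts `neg` for every row, assuming the standard class counts of `PimaIndiansDiabetes` (500 `neg`, 268 `pos`):

```r
# grid_regular() spaces values evenly on the transformed (log10) scale
log_range <- c(-3, -1)
spw_values <- 10^seq(log_range[1], log_range[2], length.out = 3)
spw_values

# The 10 to 200 range from the question, expressed on the log10 scale
log10(c(10, 200))

# Assumed class counts for mlbench::PimaIndiansDiabetes (768 rows)
n_neg <- 500
n_pos <- 268

# Accuracy of a constant model that predicts "neg" for every row
round(n_neg / (n_neg + n_pos), 3)
```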

Ah, thanks! I had updated, but I think I still managed to load the old version because I have multiple library paths (.libPaths()).

Appreciate your help!
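When several library paths are in use, R loads a package from the first directory in `.libPaths()` that contains it, so an update installed into another library can be shadowed by an older copy. A small base-R sketch to see which copy and version R picks up; the package names are just the ones relevant to this thread:

```r
# Package names relevant to this thread; adjust as needed
pkgs <- c("parsnip", "dials", "tune")

# Library directories, in the order R searches them
lib_dirs <- .libPaths()
print(lib_dirs)

for (p in pkgs) {
  # find.package() returns the path of the copy R uses
  # (the already-loaded one, if the package is attached)
  path <- find.package(p, quiet = TRUE)
  if (length(path) == 0) {
    message(p, ": not installed")
  } else {
    version <- packageVersion(p, lib.loc = dirname(path))
    message(p, " ", as.character(version), " loaded from ", dirname(path))
  }
}
```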

Hi @Gus & others,

I still don't quite understand this. If I put scale_pos_weight = tune() in the set_engine() part of the model specification, how do I actually get it to tune over a range of values when using the default (unspecified) grid?

xgb_rs.all <- tune_race_anova(
  pt_xgb_workflow2,
  resamples = cv_all,
  grid = 100,
  metrics = my_metrics,
  control = control_race(verbose_elim = TRUE)
)

The above fails with the following warning:

Warning message: This tuning result has notes. Example notes on model fitting include: internal: Error: Can't subset columns that don't exist. ✖ Column `scale_pos_weight` doesn't exist.

  1. Is there a default range over which tidymodels searches scale_pos_weight? I know the general recommendation is the ratio of negative to positive cases (neg/pos).
  2. How do I specify the range for one hyperparameter if I want to make use of the defaults for other hyperparameters?
  3. Where can I find what space is being considered for each hyperparameter?

Thanks!!!!

You can find information on the parameters by

  • loading dials and using ?scale_pos_weight
  • checking the reference page on the dials pkgdown site.
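For example, printing the parameter object shows its default range. As far as I can tell from the dials documentation, `scale_pos_weight()` defaults to `c(0.8, 1.2)`, which does not reach the neg/pos ratio mentioned in the question (for `PimaIndiansDiabetes`, assuming 500 `neg` and 268 `pos`, that is about 1.87). A sketch that inspects the default and builds a custom range around that ratio; the bounds 1 and 3 are an illustrative assumption:

```r
library(dials)

# Default parameter object; printing shows its label and range
spw_default <- scale_pos_weight()
spw_default
range_get(spw_default)

# Ratio of negative to positive cases (assumed Pima class counts)
neg_pos_ratio <- 500 / 268
round(neg_pos_ratio, 2)

# A custom range that brackets the ratio (illustrative bounds)
spw_custom <- scale_pos_weight(range = c(1, 3))
range_get(spw_custom)
value_seq(spw_custom, n = 5)
```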

I'm not sure why the error occurs. Maybe an older version of dials? Can you show the results of

library(tidymodels)

without using reprex (and maybe a full reprex to show the error)?

@Max,

Thanks for your patience with my basic questions here! I have spent loads of time on the wonderful pkgdown site, but I don't always end up finding all of the relevant bits (that's obviously on me; it's beautifully organized). Apologies for asking questions that are answered there; I'm always happy to read any tutorials and documentation you link to.

Information on parameters: for parameters with an unknown() component, how can I see which value was filled in (in one of the resulting objects), and where does the documentation give the rule for calculating, e.g., the maximum value of mtry? (I assume the maximum for mtry is the number of columns, but there are other parameters I haven't touched that also list unknown(), like sample_size, and I'd love to know where to see how their default values are chosen when the user doesn't specify them.)
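Parameters with an `unknown()` bound are resolved by `finalize()`; the dials `?finalize` help page lists the helper functions involved, such as `get_p()`, which sets the upper bound to the number of predictor columns. A minimal sketch using the built-in `mtcars` data as a stand-in predictor set (the choice of data is purely illustrative):

```r
library(dials)

# mtry's upper bound depends on the data, so it starts as unknown()
mtry_param <- mtry()
has_unknowns(mtry_param)

# Stand-in predictors: mtcars without mpg (illustrative only)
predictors <- mtcars[, setdiff(names(mtcars), "mpg")]

# finalize() replaces the unknown upper bound with ncol(predictors)
mtry_final <- finalize(mtry_param, predictors)
range_get(mtry_final)

# scale_pos_weight has no unknown bounds, so it needs a custom
# range rather than finalize()
has_unknowns(scale_pos_weight())
```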

My tidymodels installation seems to be up to date; I'm sure this is a user error in where I'm specifying scale_pos_weight = tune(), or in my not specifying a range.

Custom range for scale_pos_weight: the default scale_pos_weight range doesn't work for me, but I haven't yet figured out what needs to happen to provide a custom range (e.g., finalize()? Most of the examples I've seen have hyperparameters that are either tuned over default ranges or fixed, not tuned over custom ranges).

library(mlbench)
library(forcats)
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.0.5
#> Warning in system("timedatectl", intern = TRUE): running command 'timedatectl'
#> had status 1
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
#> Warning: package 'broom' was built under R version 4.0.5
#> Warning: package 'dials' was built under R version 4.0.5
#> Warning: package 'infer' was built under R version 4.0.5
#> Warning: package 'modeldata' was built under R version 4.0.5
#> Warning: package 'parsnip' was built under R version 4.0.5
#> Warning: package 'recipes' was built under R version 4.0.5
#> Warning: package 'rsample' was built under R version 4.0.5
#> Warning: package 'tibble' was built under R version 4.0.5
#> Warning: package 'tune' was built under R version 4.0.5
#> Warning: package 'workflows' was built under R version 4.0.5
#> Warning: package 'workflowsets' was built under R version 4.0.5
library(finetune)
#> Warning: package 'finetune' was built under R version 4.0.5

data("PimaIndiansDiabetes")

set.seed(24)

df <- PimaIndiansDiabetes %>%
  mutate(diabetes = fct_relevel(diabetes, 'pos'))

xgb_rec <- recipe(diabetes ~ ., data = df)

xgb_spec <- boost_tree(
  trees = tune()) %>%
  set_engine("xgboost", scale_pos_weight = tune()) %>%
  set_mode("classification")

resamples_cv <- vfold_cv(df, v = 5)
my_metrics <- metric_set(mn_log_loss, roc_auc, pr_auc)

xgb_wf <- workflow() %>%
  add_recipe(xgb_rec) %>%
  add_model(xgb_spec)

xgb_rs <- tune_race_anova(
  xgb_wf,
  resamples = resamples_cv,
  grid = 10,
  metrics = my_metrics,
  control = control_race(verbose_elim = TRUE)
)
#> Error: The workflow has arguments to be tuned that are missing some parameter objects: 'scale_pos_weight'

sessionInfo()
#> R version 4.0.4 (2021-02-15)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: CloudForms
#> 
#> Matrix products: default
#> BLAS:   /usr/local/lib64/R/lib/libRblas.so
#> LAPACK: /usr/local/lib64/R/lib/libRlapack.so
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
#>  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
#>  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] finetune_0.1.0     yardstick_0.0.8    workflowsets_0.1.0 workflows_0.2.4   
#>  [5] tune_0.1.6         tidyr_1.1.4        tibble_3.1.5       rsample_0.1.1     
#>  [9] recipes_0.1.17     purrr_0.3.4        parsnip_0.1.7      modeldata_0.1.1   
#> [13] infer_1.0.0        ggplot2_3.3.5      dplyr_1.0.7        dials_0.0.10      
#> [17] scales_1.1.1       broom_0.7.10       tidymodels_0.1.4   forcats_0.5.1     
#> [21] mlbench_2.1-3     
#> 
#> loaded via a namespace (and not attached):
#>  [1] nlme_3.1-152       fs_1.5.0           lubridate_1.7.10   DiceDesign_1.9    
#>  [5] tools_4.0.4        backports_1.2.1    utf8_1.2.2         R6_2.5.1          
#>  [9] rpart_4.1-15       DBI_1.1.1          colorspace_2.0-2   nnet_7.3-15       
#> [13] withr_2.4.2        tidyselect_1.1.1   compiler_4.0.4     cli_3.1.0         
#> [17] stringr_1.4.0      digest_0.6.28      minqa_1.2.4        rmarkdown_2.11    
#> [21] pkgconfig_2.0.3    htmltools_0.5.2    parallelly_1.24.0  lme4_1.1-27.1     
#> [25] styler_1.4.1       lhs_1.1.1          fastmap_1.1.0      highr_0.9         
#> [29] rlang_0.4.12       rstudioapi_0.13    generics_0.1.1     jsonlite_1.7.2    
#> [33] magrittr_2.0.1     Matrix_1.3-2       Rcpp_1.0.7         munsell_0.5.0     
#> [37] fansi_0.5.0        GPfit_1.0-8        lifecycle_1.0.1    furrr_0.2.2       
#> [41] stringi_1.7.5      pROC_1.17.0.1      yaml_2.2.1         MASS_7.3-53       
#> [45] plyr_1.8.6         grid_4.0.4         parallel_4.0.4     listenv_0.8.0     
#> [49] crayon_1.4.2       lattice_0.20-41    splines_4.0.4      knitr_1.36        
#> [53] pillar_1.6.3       boot_1.3-26        xgboost_1.4.1.1    codetools_0.2-18  
#> [57] reprex_2.0.0       glue_1.5.1         evaluate_0.14      data.table_1.14.0 
#> [61] nloptr_1.2.2.2     vctrs_0.3.8        foreach_1.5.1      gtable_0.3.0      
#> [65] future_1.21.0      assertthat_0.2.1   xfun_0.28          gower_0.2.2       
#> [69] prodlim_2019.11.13 class_7.3-18       survival_3.2-7     timeDate_3043.102 
#> [73] iterators_1.0.13   hardhat_0.1.6      lava_1.6.9         globals_0.14.0    
#> [77] ellipsis_0.3.2     ipred_0.9-12

Created on 2022-01-04 by the reprex package (v2.0.0)

Thank you, thank you!!
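The `missing some parameter objects` error suggests tune could not find a dials parameter object for `scale_pos_weight` with which to build `grid = 10`. One way to address it, while keeping the default range for `trees`, is to extract the workflow's parameter set, `update()` only the `scale_pos_weight` entry, and pass the result through `param_info`. A sketch, assuming finetune and xgboost are installed, a version recent enough to have `extract_parameter_set_dials()` (older releases used `parameters()` for the same purpose), and an illustrative range of 1 to 3:

```r
library(tidymodels)
library(finetune)
library(mlbench)

data("PimaIndiansDiabetes")
set.seed(24)
# Make "pos" the first factor level, as in the reprex above
df <- PimaIndiansDiabetes %>%
  mutate(diabetes = relevel(diabetes, ref = "pos"))

xgb_rec <- recipe(diabetes ~ ., data = df)

xgb_spec <- boost_tree(trees = tune()) %>%
  set_engine("xgboost", scale_pos_weight = tune()) %>%
  set_mode("classification")

xgb_wf <- workflow() %>%
  add_recipe(xgb_rec) %>%
  add_model(xgb_spec)

# Override only scale_pos_weight (illustrative range);
# trees keeps its dials default
xgb_params <- extract_parameter_set_dials(xgb_wf) %>%
  update(scale_pos_weight = scale_pos_weight(range = c(1, 3)))

resamples_cv <- vfold_cv(df, v = 5)
my_metrics <- metric_set(mn_log_loss, roc_auc, pr_auc)

# param_info supplies the ranges used to build the grid = 10 candidates
xgb_rs <- tune_race_anova(
  xgb_wf,
  resamples = resamples_cv,
  grid = 10,
  param_info = xgb_params,
  metrics = my_metrics,
  control = control_race(verbose_elim = TRUE)
)

show_best(xgb_rs, metric = "mn_log_loss")
```

Which class xgboost treats as "positive" depends on how the factor levels are encoded, so it is worth checking that before deciding whether the range should sit above or below 1.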