Importance weights in XGBoost

Hi, I just want to confirm: are the importance weights passed on to xgboost when `xgb.train()` is called? I don't see any weights in the printed model call below.

library(tidymodels)
library(hardhat)

# Create an unfitted boost_tree() specification whose `trees` argument is
# marked for tuning.
#
# engine: name of the modelling engine (default "xgboost").
# mode:   model mode (default "classification").
# ...:    engine-specific arguments, forwarded unchanged to set_engine().
#
# Returns the model specification.
boost_tree_spec <- function(engine = "xgboost", mode = "classification", ...) {
    spec <- boost_tree(trees = tune())
    spec <- set_mode(spec, mode)
    set_engine(spec, engine, ...)
}

# Build the (unprepped) preprocessing recipe used by the xgboost workflows.
#
# data: template data frame. `signal` is the outcome and every other column
#       starts out as a predictor.
#
# Steps:
# - `symbol` is moved to the "info" role, so it is carried along but not
#   modelled.
# - `date` is replaced by a day-of-week feature.
# - All nominal predictors are dummy-encoded.
#
# NOTE(review): a hardhat case-weights column such as `importance` is
# presumably given the "case_weights" role automatically by recipes
# (>= 1.0.0) rather than treated as a predictor. Confirm via summary().
#
# Returns the recipe object.
xgboost_recipe <- function(data) {
    # Supply the template once, by name. The previous version also piped
    # `data` in as the first argument, passing it twice.
    recipes::recipe(signal ~ ., data = data) %>%
        update_role(symbol, new_role = "info") %>%
        step_date(date, features = "dow", keep_original_cols = FALSE) %>%
        step_dummy(all_nominal_predictors())
}

# xgboost engine settings, forwarded through boost_tree_spec().
# The tuning fits use a coarser histogram (fewer bins) than the final fit.
tune_tree_method <- "hist"
tune_max_bin <- 256
tune_nthread <- 0

final_tree_method <- "hist"
final_max_bin <- 512
final_nthread <- 0

# Seed the RNG so that the simulated data, and later random steps such as
# initial_split() and vfold_cv(), can be reproduced.
# NOTE: the output shown in this reprex was generated without a seed, so a
# rerun will not match it value for value.
set.seed(2022)

# Simulated data. `signal` is the outcome. `importance` holds per-row weights
# drawn from the integers 0-100, so a row may receive weight 0.
# importance_weights() marks them as importance weights. Per the
# hardhat/tidymodels documentation, these affect model fitting but not
# performance metrics.
data <- tibble(
    date = sample(
        seq(as.Date("1999/01/01"), as.Date("2000/01/01"), by = "day"),
        100
    ),
    f1 = sample(0:1000, 100),
    symbol = sample(c("s1", "s2", "s3", "s4"), 100, replace = TRUE),
    importance = importance_weights(sample(0:100, 100)),
    signal = sample(c("open", "close"), 100, replace = TRUE)
)

# Check column types: `importance` should appear as <imp_wts>.
glimpse(data)
#> Rows: 100
#> Columns: 5
#> $ date       <date> 1999-01-25, 1999-03-06, 1999-06-24, 1999-10-23, 1999-08-19…
#> $ f1         <int> 314, 814, 983, 888, 620, 750, 913, 297, 439, 155, 460, 698,…
#> $ symbol     <chr> "s1", "s4", "s3", "s2", "s4", "s3", "s2", "s2", "s2", "s2",…
#> $ importance <imp_wts> 50, 0, 89, 74, 98, 82, 65, 31, 75, 32, 100, 52, 35, 71,…
#> $ signal     <chr> "open", "close", "open", "open", "open", "close", "close", …

# Train/test split, stratified on the outcome.
split <- initial_split(data, strata = signal)

# Training portion of the split, used for tuning below.
tune_data <- training(split)
print(tune_data)
#> # A tibble: 74 × 5
#>    date          f1 symbol importance signal
#>    <date>     <int> <chr>   <imp_wts> <chr> 
#>  1 1999-03-06   814 s4              0 close 
#>  2 1999-08-02   750 s3             82 close 
#>  3 1999-01-26   913 s2             65 close 
#>  4 1999-11-02   297 s2             31 close 
#>  5 1999-07-26   460 s1            100 close 
#>  6 1999-01-23   338 s2             71 close 
#>  7 1999-12-18   465 s3             39 close 
#>  8 1999-02-17   294 s2             64 close 
#>  9 1999-11-26   749 s4              7 close 
#> 10 1999-09-24   333 s1             37 close 
#> # … with 64 more rows

# Tune the number of trees with 2-fold cross-validation on the training data.
# `importance` is attached to the workflow as case weights, and candidates
# are scored by ROC AUC.
tune_result <- workflow() |>
    add_model(
        boost_tree_spec(
            nthread = tune_nthread,
            tree_method = tune_tree_method,
            max_bin = tune_max_bin
        )
    ) |>
    add_recipe(xgboost_recipe(tune_data)) |>
    add_case_weights(importance) |>
    tune_grid(
        resamples = vfold_cv(tune_data, v = 2, strata = signal),
        grid = 5,
        metrics = metric_set(roc_auc),
        control = control_grid(verbose = TRUE)
    )
#> i Fold1: preprocessor 1/1
#> ✓ Fold1: preprocessor 1/1
#> i Fold1: preprocessor 1/1, model 1/1
#> ✓ Fold1: preprocessor 1/1, model 1/1
#> i Fold1: preprocessor 1/1, model 1/1 (predictions)
#> i Fold2: preprocessor 1/1
#> ✓ Fold2: preprocessor 1/1
#> i Fold2: preprocessor 1/1, model 1/1
#> ✓ Fold2: preprocessor 1/1, model 1/1
#> i Fold2: preprocessor 1/1, model 1/1 (predictions)

# Pick the tuning candidate (number of trees) with the best ROC AUC.
# The metric is named explicitly instead of relying on select_best()'s
# fallback to the first metric in the tuning metric set.
best_params <- tune_result %>%
    select_best(metric = "roc_auc")

# Same workflow as for tuning, but with the final engine settings and the
# selected number of trees filled in.
final_workflow <- workflow() |>
    add_model(
        boost_tree_spec(
            nthread = final_nthread,
            tree_method = final_tree_method,
            max_bin = final_max_bin
        )
    ) |>
    add_recipe(xgboost_recipe(tune_data)) |>
    add_case_weights(importance) |>
    finalize_workflow(best_params)

# Fit the finalized workflow on the training set, evaluate it once on the
# held-out test set, and print the resulting metrics.
last_fit(final_workflow, split) |>
    collect_metrics() |>
    print()
#> # A tibble: 2 × 4
#>   .metric  .estimator .estimate .config             
#>   <chr>    <chr>          <dbl> <chr>               
#> 1 accuracy binary         0.5   Preprocessor1_Model1
#> 2 roc_auc  binary         0.542 Preprocessor1_Model1

# Fit the finalized workflow on the training data and print it.
#
# NOTE(review): the printed xgb.train() call has no separate weights
# argument. The workflow does record the case weights (see the
# "Case Weights" section of the printout). parsnip's xgboost wrapper is
# expected to attach them to the training xgb.DMatrix, which is passed as
# `data = x$data`. Confirm against parsnip's xgb_train() source.
#
# This repeats the training-set fit that last_fit() already performed. Keeping
# the last_fit() result and calling extract_workflow() on it would avoid that.
model <- fit(final_workflow, training(split))
print(model)
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 2 Recipe Steps
#> 
#> • step_date()
#> • step_dummy()
#> 
#> ── Case Weights ────────────────────────────────────────────────────────────────
#> importance
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> ##### xgb.Booster
#> raw: 959.3 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "binary:logistic"), data = x$data, 
#>     nrounds = 1019L, watchlist = x$watchlist, verbose = 0, nthread = 0, 
#>     tree_method = "hist", max_bin = 512)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "binary:logistic", nthread = "0", tree_method = "hist", max_bin = "512", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 7 
#> niter: 1019
#> nfeatures : 7 
#> evaluation_log:
#>     iter training_logloss
#>        1      0.564311269
#>        2      0.496895742
#> ---                      
#>     1018      0.001333734
#>     1019      0.001333456

Created on 2022-07-16 by the reprex package (v2.0.1)

Session info
# Record R and package versions for the reprex.
print(sessioninfo::session_info())
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.4
#>  system   x86_64, darwin21.5.0
#>  ui       unknown
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Asia/Shanghai
#>  date     2022-07-16
#>  pandoc   2.18 @ /usr/local/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.2.1)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.1)
#>  broom        * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  class          7.3-20     2022-01-16 [2] CRAN (R 4.2.1)
#>  cli            3.3.0      2022-04-25 [1] CRAN (R 4.2.1)
#>  codetools      0.2-18     2020-11-04 [2] CRAN (R 4.2.1)
#>  colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.2.1)
#>  crayon         1.5.1      2022-03-26 [1] CRAN (R 4.2.1)
#>  data.table     1.14.2     2021-09-27 [1] CRAN (R 4.2.1)
#>  DBI            1.1.3      2022-06-18 [1] CRAN (R 4.2.1)
#>  dials        * 1.0.0      2022-06-14 [1] CRAN (R 4.2.1)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.1)
#>  digest         0.6.29     2021-12-01 [1] CRAN (R 4.2.1)
#>  dplyr        * 1.0.9      2022-04-28 [1] CRAN (R 4.2.1)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.2.1)
#>  evaluate       0.15       2022-02-18 [1] CRAN (R 4.2.1)
#>  fansi          1.0.3      2022-03-24 [1] CRAN (R 4.2.1)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.2.1)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.1)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.2.1)
#>  furrr          0.3.0      2022-05-04 [1] CRAN (R 4.2.1)
#>  future         1.26.1     2022-05-27 [1] CRAN (R 4.2.1)
#>  future.apply   1.9.0      2022-04-25 [1] CRAN (R 4.2.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.1)
#>  ggplot2      * 3.3.6      2022-05-03 [1] CRAN (R 4.2.1)
#>  globals        0.15.1     2022-06-24 [1] CRAN (R 4.2.1)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.1)
#>  gower          1.0.0      2022-02-03 [1] CRAN (R 4.2.1)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.1)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.2.1)
#>  hardhat      * 1.2.0      2022-06-30 [1] CRAN (R 4.2.1)
#>  highr          0.9        2021-04-16 [1] CRAN (R 4.2.1)
#>  htmltools      0.5.2      2021-08-25 [1] CRAN (R 4.2.1)
#>  infer        * 1.0.2      2022-06-10 [1] CRAN (R 4.2.1)
#>  ipred          0.9-13     2022-06-02 [1] CRAN (R 4.2.1)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.1)
#>  jsonlite       1.8.0      2022-02-22 [1] CRAN (R 4.2.1)
#>  knitr          1.39       2022-04-26 [1] CRAN (R 4.2.1)
#>  lattice        0.20-45    2021-09-22 [2] CRAN (R 4.2.1)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.2.1)
#>  lhs            1.1.5      2022-03-22 [1] CRAN (R 4.2.1)
#>  lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.2.1)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.2.1)
#>  lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.2.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.1)
#>  MASS           7.3-58     2022-07-14 [2] CRAN (R 4.2.1)
#>  Matrix         1.4-1      2022-03-23 [2] CRAN (R 4.2.1)
#>  modeldata    * 1.0.0      2022-07-01 [1] CRAN (R 4.2.1)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.1)
#>  nnet           7.3-17     2022-01-16 [2] CRAN (R 4.2.1)
#>  parallelly     1.32.0     2022-06-07 [1] CRAN (R 4.2.1)
#>  parsnip      * 1.0.0      2022-06-16 [1] CRAN (R 4.2.1)
#>  pillar         1.7.0      2022-02-01 [1] CRAN (R 4.2.1)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.1)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.2.1)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.2.1)
#>  R.cache        0.15.0     2021-04-30 [1] CRAN (R 4.2.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.2.1)
#>  R.oo           1.25.0     2022-06-12 [1] CRAN (R 4.2.1)
#>  R.utils        2.12.0     2022-06-28 [1] CRAN (R 4.2.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.1)
#>  Rcpp           1.0.9      2022-07-08 [1] CRAN (R 4.2.1)
#>  recipes      * 1.0.1      2022-07-07 [1] CRAN (R 4.2.1)
#>  reprex         2.0.1      2021-08-05 [1] CRAN (R 4.2.1)
#>  rlang          1.0.4      2022-07-12 [1] CRAN (R 4.2.1)
#>  rmarkdown      2.14       2022-04-25 [1] CRAN (R 4.2.1)
#>  rpart          4.1.16     2022-01-24 [2] CRAN (R 4.2.1)
#>  rsample      * 1.0.0      2022-06-24 [1] CRAN (R 4.2.1)
#>  scales       * 1.2.0      2022-04-13 [1] CRAN (R 4.2.1)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.1)
#>  stringi        1.7.8      2022-07-11 [1] CRAN (R 4.2.1)
#>  stringr        1.4.0      2019-02-10 [1] CRAN (R 4.2.1)
#>  styler         1.7.0      2022-03-13 [1] CRAN (R 4.2.1)
#>  survival       3.3-1      2022-03-03 [2] CRAN (R 4.2.1)
#>  tibble       * 3.1.7      2022-05-03 [1] CRAN (R 4.2.1)
#>  tidymodels   * 1.0.0      2022-07-13 [1] CRAN (R 4.2.1)
#>  tidyr        * 1.2.0      2022-02-01 [1] CRAN (R 4.2.1)
#>  tidyselect     1.1.2      2022-02-21 [1] CRAN (R 4.2.1)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.2.1)
#>  tune         * 1.0.0      2022-07-07 [1] CRAN (R 4.2.1)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.2.1)
#>  vctrs          0.4.1      2022-04-13 [1] CRAN (R 4.2.1)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.1)
#>  workflows    * 1.0.0      2022-07-05 [1] CRAN (R 4.2.1)
#>  workflowsets * 1.0.0      2022-07-12 [1] CRAN (R 4.2.1)
#>  xfun           0.31       2022-05-10 [1] CRAN (R 4.2.1)
#>  xgboost      * 1.6.0.1    2022-04-16 [1] CRAN (R 4.2.1)
#>  yaml           2.3.5      2022-02-21 [1] CRAN (R 4.2.1)
#>  yardstick    * 1.0.0      2022-06-06 [1] CRAN (R 4.2.1)
#> 
#>  [1] /usr/local/lib/R/4.2/site-library
#>  [2] /usr/local/Cellar/r/4.2.1/lib/R/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

This topic was automatically closed 21 days after the last reply. New replies are no longer allowed.

If you have a query related to it or one of the replies, start a new topic and refer back with a link.