Passing arguments to parsnip::set_engine()

In the code below, I am trying to pass a vector of case weights to xgboost using the tidymodels framework. My understanding is that arguments can be passed to the underlying model function with parsnip::set_engine(), but I am unclear on how to pass the weights through to xgboost::xgb.train() correctly. I've tried a few ideas and tinkered with the variable roles, but have not had success. Any help is appreciated!
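
For reference, my rough mental model of what I'm trying to reach: when calling xgboost directly, the weights are attached to the xgb.DMatrix rather than passed as an ordinary training parameter. A minimal, untested sketch using the simulated data from the reprex below (object names and parameter values are just placeholders):

library(xgboost)

# predictors as a numeric matrix (dummy-encodes pred_2); the weights live on the DMatrix
x_mat  <- model.matrix(outcome ~ . - the_weights, data = data)[, -1]
dtrain <- xgb.DMatrix(data = x_mat, label = data$outcome, weight = data$the_weights)

fit_direct <- xgb.train(
  params  = list(eta = 0.1, max_depth = 4, objective = "reg:squarederror"),
  data    = dtrain,
  nrounds = 200
)

So what I'm really after is the tidymodels route to that weight slot on the DMatrix.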

library(tidymodels)
#> -- Attaching packages -------------------------------------------------------------- tidymodels 0.1.1 --
#> v broom     0.7.0      v recipes   0.1.13
#> v dials     0.0.8      v rsample   0.0.7 
#> v dplyr     1.0.0      v tibble    3.0.3 
#> v ggplot2   3.3.2      v tidyr     1.1.0 
#> v infer     0.5.3      v tune      0.1.1 
#> v modeldata 0.0.2      v workflows 0.1.2 
#> v parsnip   0.1.2      v yardstick 0.0.7 
#> v purrr     0.3.4
#> -- Conflicts ----------------------------------------------------------------- tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

data <- tibble(outcome = rnorm(3000, 100, 15),
               pred_1 = outcome + rnorm(3000, 0, .6),
               pred_2 = sample(c("lev1", "lev2", "lev3"), 
                               size = 3000, 
                               replace = TRUE),
               the_weights = round(runif(3000, 1, 7), 0))

data <- mutate_if(data, is.character, factor) 

data_split <- initial_split(data, 
                            prop = .75, 
                            strata = outcome) 

training <- training(data_split) 
testing <- testing(data_split)

my_recipe <- recipe(outcome ~ ., data = training) %>%  
  step_nzv(all_nominal()) %>% 
  step_dummy(all_nominal(), one_hot = TRUE) %>% 
  update_role(the_weights, new_role = "weights") # do I need a new role here?
  
xgb_spec <- boost_tree(trees = 200, 
                       tree_depth = tune(),
                       mtry = tune(),        
                       learn_rate = tune()) %>% 
  set_engine("xgboost", params = list(weight = the_weights)) %>% #attempting to pass the weights to xgb.train()
  set_mode("regression")

xgb_grid <- grid_latin_hypercube(
  tree_depth(),
  finalize(mtry(), training), 
  learn_rate(),
  size = 6
)

xgb_wf <- workflow() %>%
  add_recipe(my_recipe) %>% 
  add_model(xgb_spec) 

xgb_folds <- vfold_cv(training, strata = outcome, v = 2)

xgb_res <- tune_grid(
  object = xgb_wf,
  grid = xgb_grid,
  resamples = xgb_folds,
  control = control_grid(save_pred = TRUE)
)
#> x Fold1: model 1/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 2/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 3/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 4/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 5/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold1: model 6/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 1/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 2/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 3/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 4/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 5/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> x Fold2: model 6/6: Error in ~list(weight = the_weights): object 'the_weights' no...
#> Warning: All models failed in tune_grid(). See the `.notes` column.

Created on 2020-08-05 by the reprex package (v0.3.0)

Update: tune::tune_grid() does not throw an error if I specify the engine as follows:

set_engine("xgboost", params = list(weight = .$the_weights))

@julia @Max @davis, any chance you can weigh in?

Am I specifying this weighted xgboost model correctly above using update_role() in the recipe and params in set_engine()?

We don't support case weights yet; it's on the list but not there yet.

Thank you for your reply, @Max!

In the meantime, I can use tidyr::uncount() to expand the weights into repeated rows. Do you have a sense of how this functionality might be prioritized?
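
For concreteness, the workaround I have in mind is roughly the sketch below (it assumes the weights are whole numbers, since uncount() repeats each row an integer number of times; object names are just illustrative):

# repeat each row according to its integer weight; uncount() drops the
# weight column by default, so no update_role() step is needed afterwards
expanded <- tidyr::uncount(data, weights = the_weights)
expanded_split <- initial_split(expanded, prop = .75, strata = outcome)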
