fit_split missing model argument when using a workflow

I'm running into an issue with the custom fit_split() function that @apreshill shows in this great tidymodels training resource. I get an error complaining that the model argument is missing, but I think the model is defined in the workflow.

library(modeldata)
data("stackoverflow")
library(tidyverse)
library(tidymodels)

set.seed(100) # Important!

# make smaller to save time
so_split <- initial_split(sample_n(stackoverflow, size = 300),
                          strata = Remote)
so_train <- training(so_split)
so_test  <- testing(so_split)

# again, simpler so it runs faster
so_folds <- vfold_cv(so_train, v = 2, strata = Remote)

# recipe
tune_rec <- recipe(Remote ~ .,
                   data = so_train) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_lincomb(all_predictors()) %>%
  # in current releases, step_downsample() comes from the themis package
  step_downsample(Remote, under_ratio = tune())

# model
tune_spec <-
  rand_forest(mtry = tune(),
              min_n = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")

# workflow
tuneboth_wf <-
  workflow() %>%
  add_recipe(tune_rec) %>%
  add_model(tune_spec)

# tuning parameters
tuneboth_param <- parameters(tuneboth_wf)

tuneboth_param <-
  tuneboth_param %>%
  # Pick an upper bound for mtry:
  update(mtry = mtry(c(1, 20)))

# fit
results <-
  tuneboth_wf %>%
  tune_grid(resamples = so_folds,
            param_info = tuneboth_param)

# get best tuning parameters
best <-
  results %>%
  select_best(metric = "roc_auc")

# define final workflow
wf_final <-
  tuneboth_wf %>%
  finalize_workflow(best)

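# re-run with best
fit_split <- function(formula, model, split, ...) {
  # build a fresh workflow from a formula (blueprint keeps factors
  # unexpanded and allows novel levels) plus a parsnip model spec
  wf <- workflows::add_model(
    workflows::add_formula(workflows::workflow(),
                           formula,
                           blueprint =
                             hardhat::default_formula_blueprint(
                               indicators = FALSE,
                               allow_novel_levels = TRUE)),
    model)
  # fit on the training portion of split, evaluate on the test portion
  tune::last_fit(wf, split, ...)
}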
  
results_best_test <-
  wf_final %>%
  fit_split(split = so_split,
            metrics = metric_set(roc_auc, sens, spec))
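Why model ends up missing can be seen without tidymodels at all. The pipe hands its left-hand side to the first positional argument, and because split and metrics are passed by name, the workflow is bound to the first remaining formal of fit_split(), which is formula; nothing is ever bound to model, and R only complains once the function body tries to use it. A minimal base-R sketch, using a hypothetical stand-in toy_fit_split() with the same signature and a fake object classed "workflow" (both assumptions for illustration; the native pipe needs R 4.1 or later, and magrittr's %>% inserts the left-hand side the same way here):

```r
# Hypothetical stand-in with the same signature as fit_split(formula, model, split, ...)
toy_fit_split <- function(formula, model, split, ...) {
  list(
    formula_class = class(formula),  # whatever was bound to the first slot
    model_missing = missing(model),  # TRUE when no argument was matched to model
    extras = names(list(...))        # named arguments that fell through to ...
  )
}

# A bare object classed "workflow" stands in for wf_final (assumption)
fake_wf <- structure(list(), class = "workflow")

# Same shape as the call in the reprex: only split and metrics are named
res <- fake_wf |> toy_fit_split(split = "so_split", metrics = "metric_set")

res$formula_class  # "workflow"
res$model_missing  # TRUE
res$extras         # "metrics"
```

So fit_split() as written expects a formula and a parsnip spec as two separate arguments; a workflow that already bundles both would instead go straight to tune::last_fit(), as suggested later in the thread.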

What kind of object is wf_final? Is it possible to post a relevant sample of it?

It's a workflow.

> wf_final
══ Workflow ═════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ─────────────────────────────────────────────────────────────────────────────────────
3 Recipe Steps

● step_dummy()
● step_lincomb()
● step_downsample()

── Model ────────────────────────────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = 9
  min_n = 19

Computational engine: ranger

What about str(wf_final), @ericpgreen?

> str(wf_final)
List of 4
 $ pre    :List of 2
  ..$ actions:List of 1
  .. ..$ recipe:List of 2
  .. .. ..$ recipe   :List of 6
  .. .. .. ..$ var_info :Classes ‘tbl_df’, ‘tbl’ and 'data.frame':	21 obs. of  4 variables:
  .. .. .. .. ..$ variable: chr [1:21] "Country" "Salary" "YearsCodedJob" "OpenSource" ...
  .. .. .. .. ..$ type    : chr [1:21] "nominal" "numeric" "numeric" "numeric" ...
  .. .. .. .. ..$ role    : chr [1:21] "predictor" "predictor" "predictor" "predictor" ...
  .. .. .. .. ..$ source  : chr [1:21] "original" "original" "original" "original" ...
  .. .. .. ..$ term_info:Classes ‘tbl_df’, ‘tbl’ and 'data.frame':	21 obs. of  4 variables:
  .. .. .. .. ..$ variable: chr [1:21] "Country" "Salary" "YearsCodedJob" "OpenSource" ...
  .. .. .. .. ..$ type    : chr [1:21] "nominal" "numeric" "numeric" "numeric" ...
  .. .. .. .. ..$ role    : chr [1:21] "predictor" "predictor" "predictor" "predictor" ...
  .. .. .. .. ..$ source  : chr [1:21] "original" "original" "original" "original" ...
  .. .. .. ..$ steps    :List of 3
  .. .. .. .. ..$ :List of 9
  .. .. .. .. .. ..$ terms   :List of 2
  .. .. .. .. .. .. ..$ : language ~all_nominal()
  .. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc8986e40d8> 
  .. .. .. .. .. .. ..$ : language ~-all_outcomes()
  .. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc8986e40d8> 
  .. .. .. .. .. .. ..- attr(*, "class")= chr "quosures"
  .. .. .. .. .. ..$ role    : chr "predictor"
  .. .. .. .. .. ..$ trained : logi FALSE
  .. .. .. .. .. ..$ one_hot : logi FALSE
  .. .. .. .. .. ..$ preserve: logi FALSE
  .. .. .. .. .. ..$ naming  :function (var, lvl, ordinal = FALSE, sep = "_")  
  .. .. .. .. .. ..$ levels  : NULL
  .. .. .. .. .. ..$ skip    : logi FALSE
  .. .. .. .. .. ..$ id      : chr "dummy_46eru"
  .. .. .. .. .. ..- attr(*, "class")= chr [1:2] "step_dummy" "step"
  .. .. .. .. ..$ :List of 7
  .. .. .. .. .. ..$ terms    :List of 1
  .. .. .. .. .. .. ..$ : language ~all_predictors()
  .. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc8986fec00> 
  .. .. .. .. .. .. ..- attr(*, "class")= chr "quosures"
  .. .. .. .. .. ..$ role     : logi NA
  .. .. .. .. .. ..$ trained  : logi FALSE
  .. .. .. .. .. ..$ max_steps: num 5
  .. .. .. .. .. ..$ removals : NULL
  .. .. .. .. .. ..$ skip     : logi FALSE
  .. .. .. .. .. ..$ id       : chr "lincomp_n6IDq"
  .. .. .. .. .. ..- attr(*, "class")= chr [1:2] "step_lincomb" "step"
  .. .. .. .. ..$ :List of 11
  .. .. .. .. .. ..$ terms      :List of 1
  .. .. .. .. .. .. ..$ : language ~Remote
  .. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc898737028> 
  .. .. .. .. .. .. ..- attr(*, "class")= chr "quosures"
  .. .. .. .. .. ..$ under_ratio: num 0.991
  .. .. .. .. .. ..$ ratio      : logi NA
  .. .. .. .. .. ..$ role       : logi NA
  .. .. .. .. .. ..$ trained    : logi FALSE
  .. .. .. .. .. ..$ column     : NULL
  .. .. .. .. .. ..$ target     : logi NA
  .. .. .. .. .. ..$ skip       : logi TRUE
  .. .. .. .. .. ..$ id         : chr "downsample_a8xcH"
  .. .. .. .. .. ..$ seed       : int 24391
  .. .. .. .. .. ..$ id         : chr "downsample_a8xcH"
  .. .. .. .. .. ..- attr(*, "class")= chr [1:2] "step_downsample" "step"
  .. .. .. ..$ template :Classes ‘tbl_df’, ‘tbl’ and 'data.frame':	225 obs. of  21 variables:
  .. .. .. .. ..$ Country                             : Factor w/ 5 levels "Canada","Germany",..: 4 5 3 3 4 4 5 1 5 2 ...
  .. .. .. .. ..$ Salary                              : num [1:225] 50000 150000 17619 12480 87500 ...
  .. .. .. .. ..$ YearsCodedJob                       : int [1:225] 7 20 2 3 10 3 3 4 4 5 ...
  .. .. .. .. ..$ OpenSource                          : num [1:225] 0 0 0 1 1 0 0 1 1 0 ...
  .. .. .. .. ..$ Hobby                               : num [1:225] 0 0 1 1 1 1 1 1 1 1 ...
  .. .. .. .. ..$ CompanySizeNumber                   : num [1:225] 1000 20 1 10 100 100 100 20 10 100 ...
  .. .. .. .. ..$ CareerSatisfaction                  : int [1:225] 7 7 8 6 5 8 10 10 9 10 ...
  .. .. .. .. ..$ Data_scientist                      : num [1:225] 0 0 0 0 0 1 0 0 0 0 ...
  .. .. .. .. ..$ Database_administrator              : num [1:225] 1 0 0 0 0 0 0 1 1 0 ...
  .. .. .. .. ..$ Desktop_applications_developer      : num [1:225] 1 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ Developer_with_stats_math_background: num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ DevOps                              : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ Embedded_developer                  : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ Graphic_designer                    : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ Graphics_programming                : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ Machine_learning_specialist         : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ Mobile_developer                    : num [1:225] 0 1 1 0 0 0 0 0 1 1 ...
  .. .. .. .. ..$ Quality_assurance_engineer          : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
  .. .. .. .. ..$ Systems_administrator               : num [1:225] 0 0 0 0 0 0 0 0 1 0 ...
  .. .. .. .. ..$ Web_developer                       : num [1:225] 0 0 0 1 1 0 1 1 1 0 ...
  .. .. .. .. ..$ Remote                              : Factor w/ 2 levels "Remote","Not remote": 2 1 1 2 2 2 2 2 2 2 ...
  .. .. .. ..$ levels   : NULL
  .. .. .. ..$ retained : logi NA
  .. .. .. ..- attr(*, "class")= chr "recipe"
  .. .. ..$ blueprint:List of 8
  .. .. .. ..$ mold              :List of 2
  .. .. .. .. ..$ clean  :function (blueprint, data)  
  .. .. .. .. ..$ process:function (blueprint, data)  
  .. .. .. ..$ forge             :List of 2
  .. .. .. .. ..$ clean  :function (blueprint, new_data, outcomes)  
  .. .. .. .. ..$ process:function (blueprint, predictors, outcomes, extras)  
  .. .. .. ..$ intercept         : logi FALSE
  .. .. .. ..$ allow_novel_levels: logi FALSE
  .. .. .. ..$ ptypes            : NULL
  .. .. .. ..$ fresh             : logi FALSE
  .. .. .. ..$ recipe            : NULL
  .. .. .. ..$ extra_role_ptypes : NULL
  .. .. .. ..- attr(*, "class")= chr [1:3] "default_recipe_blueprint" "recipe_blueprint" "hardhat_blueprint"
  .. .. ..- attr(*, "class")= chr [1:3] "action_recipe" "action_pre" "action"
  ..$ mold   : NULL
  ..- attr(*, "class")= chr [1:2] "stage_pre" "stage"
 $ fit    :List of 2
  ..$ actions:List of 1
  .. ..$ model:List of 2
  .. .. ..$ spec   :List of 5
  .. .. .. ..$ args    :List of 3
  .. .. .. .. ..$ mtry : language ~9L
  .. .. .. .. .. ..- attr(*, ".Environment")=<environment: R_EmptyEnv> 
  .. .. .. .. ..$ trees: language ~NULL
  .. .. .. .. .. ..- attr(*, ".Environment")=<environment: R_EmptyEnv> 
  .. .. .. .. ..$ min_n: language ~19L
  .. .. .. .. .. ..- attr(*, ".Environment")=<environment: R_EmptyEnv> 
  .. .. .. ..$ eng_args: Named list()
  .. .. .. .. ..- attr(*, "class")= chr "quosures"
  .. .. .. ..$ mode    : chr "classification"
  .. .. .. ..$ method  : NULL
  .. .. .. ..$ engine  : chr "ranger"
  .. .. .. ..- attr(*, "class")= chr [1:2] "rand_forest" "model_spec"
  .. .. ..$ formula: NULL
  .. .. ..- attr(*, "class")= chr [1:3] "action_model" "action_fit" "action"
  ..$ fit    : NULL
  ..- attr(*, "class")= chr [1:2] "stage_fit" "stage"
 $ post   :List of 1
  ..$ actions: list()
  ..- attr(*, "class")= chr [1:2] "stage_post" "stage"
 $ trained: logi FALSE
 - attr(*, "class")= chr "workflow"

Great reprex! Even someone like me who hasn't done this can follow it, so let me ask a naive question.

Is the model that fit_split() expects in its model argument this one inside wf_final?

wf_final$fit$actions[1][[1]][1]
$spec
Random Forest Model Specification (classification)

Main Arguments:
  mtry = 3
  min_n = 13

Computational engine: ranger 

Thanks. Credit for the reprex goes to @apreshill (here) and @Max (here).

Looks like it. That's the model I defined here with tune() placeholders:

tune_spec <-
  rand_forest(mtry = tune(),
              min_n = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")

which becomes part of the workflow in this step...

tuneboth_wf <-
  workflow() %>%
  add_recipe(tune_rec) %>%
  add_model(tune_spec)

Again, I'm either playing dumb or being dumb: how does the function find the model among all the objects in wf_final?
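Judging from its definition, fit_split() never searches a workflow for a model: it builds a new workflow from its formula and model arguments. The specification does live inside the workflow, though, and after finalize_workflow() its tune() placeholders hold concrete values. A minimal sketch of reading it back out through a documented accessor rather than indexing into wf_final$fit$actions by hand, assuming a reasonably recent workflows release that provides extract_spec_parsnip(); the formula preprocessor and the hand-picked values standing in for select_best() output are assumptions to keep it short:

```r
library(parsnip)
library(workflows)
library(tune)
library(tibble)

# Same specification as tune_spec in the reprex, with tune() placeholders
spec <- rand_forest(mtry = tune(), min_n = tune()) |>
  set_engine("ranger") |>
  set_mode("classification")

# A formula preprocessor stands in for the recipe (assumption)
wf <- workflow() |>
  add_formula(Remote ~ .) |>
  add_model(spec)

# Hand-picked values standing in for select_best() output (assumption)
best <- tibble(mtry = 9L, min_n = 19L)
wf_final <- finalize_workflow(wf, best)

# Pull the finalized parsnip spec back out of the workflow
final_spec <- extract_spec_parsnip(wf_final)
final_spec
```

Passing a spec extracted this way on to fit_split() would still leave the recipe behind (step_dummy(), step_lincomb(), step_downsample()), because fit_split() wraps the spec in a fresh formula-based workflow.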

Hi Eric,

Did you try just straight up using tune::last_fit() here?
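For reference, last_fit() takes a workflow as its first argument, so the finalized workflow from the reprex can be passed as wf_final %>% last_fit(split = so_split, metrics = metric_set(roc_auc, sens, spec)), with no separate formula/model pair. A self-contained sketch of the same call shape, assuming the two_class_dat data from modeldata and a plain logistic regression as a lightweight stand-in for the tuned random forest:

```r
library(tidymodels)  # attaches rsample, parsnip, workflows, tune, yardstick, ...
data(two_class_dat, package = "modeldata")

set.seed(100)
split <- initial_split(two_class_dat, strata = Class)

# Any finalized workflow works here; logistic regression keeps it fast
wf <- workflow() |>
  add_formula(Class ~ A + B) |>
  add_model(logistic_reg() |> set_engine("glm"))

# The workflow itself is the first argument of last_fit()
final_res <- last_fit(wf, split = split,
                      metrics = metric_set(roc_auc, sens, spec))

collect_metrics(final_res)
```

collect_metrics() returns one row per requested metric, computed on the test portion of the split.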

Hi, @apreshill. Yes, I should have mentioned that last_fit() does work in this case. But in my actual use case it did not, which I kind of expected based on your teaching notes. So I fell back to trying fit_split(), encountered the error described here, and realized the reprex also threw the same error. Have you come across it before?

Maybe the better course of action is for me to figure out why last_fit() did not work rather than troubleshoot fit_split()?

Yes, I would do that for sure; the warnings will be more useful (at least, one can hope!)

OK, sounds good. Though I will miss the rhyme of fit_split().
