Running into an issue with the custom fit_split() function @apreshill shows in this great tidymodels training resource. I get an error complaining that the model argument is missing, but I think it is defined in the workflow.
library(modeldata)
data("stackoverflow")
library(tidyverse)
library(tidymodels)
set.seed(100) # Important!
# make smaller to save time
so_split <- initial_split(sample_n(stackoverflow, size = 300),
                          strata = Remote)
so_train <- training(so_split)
so_test <- testing(so_split)
# again, simpler so runs faster
so_folds <- vfold_cv(so_train, v = 2, strata = Remote)
# recipe
tune_rec <- recipe(Remote ~ .,
                   data = so_train) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_lincomb(all_predictors()) %>%
  step_downsample(Remote, under_ratio = tune())
# model
tune_spec <-
  rand_forest(mtry = tune(),
              min_n = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")
# workflow
tuneboth_wf <-
  workflow() %>%
  add_recipe(tune_rec) %>%
  add_model(tune_spec)
# tuning parameters
tuneboth_param <- parameters(tuneboth_wf)
tuneboth_param <-
  tuneboth_param %>%
  # Pick an upper bound for mtry:
  update(mtry = mtry(c(1, 20)))
# fit
results <-
  tuneboth_wf %>%
  tune_grid(resamples = so_folds,
            param_info = tuneboth_param)
# get best tuning parameters
best <-
  results %>%
  select_best(metric = "roc_auc")
# define final workflow
wf_final <-
  tuneboth_wf %>%
  finalize_workflow(best)
# re-run with best
fit_split <- function(formula, model, split, ...) {
  wf <- workflows::add_model(
    workflows::add_formula(workflows::workflow(),
                           formula,
                           blueprint =
                             hardhat::default_formula_blueprint(
                               indicators = FALSE,
                               allow_novel_levels = TRUE)),
    model)
  tune::last_fit(wf, split, ...)
}
results_best_test <-
  wf_final %>%
  fit_split(split = so_split,
            metrics = metric_set(roc_auc, sens, spec))
What kind of object is wf_final? Is it possible to post a relevant sample of it?
It's a workflow.
> wf_final
══ Workflow ═════════════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ─────────────────────────────────────────────────────────────────────────────────────
3 Recipe Steps
● step_dummy()
● step_lincomb()
● step_downsample()
── Model ────────────────────────────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = 9
min_n = 19
Computational engine: ranger
What about str(wf_final), @ericpgreen?
> str(wf_final)
List of 4
$ pre :List of 2
..$ actions:List of 1
.. ..$ recipe:List of 2
.. .. ..$ recipe :List of 6
.. .. .. ..$ var_info :Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 21 obs. of 4 variables:
.. .. .. .. ..$ variable: chr [1:21] "Country" "Salary" "YearsCodedJob" "OpenSource" ...
.. .. .. .. ..$ type : chr [1:21] "nominal" "numeric" "numeric" "numeric" ...
.. .. .. .. ..$ role : chr [1:21] "predictor" "predictor" "predictor" "predictor" ...
.. .. .. .. ..$ source : chr [1:21] "original" "original" "original" "original" ...
.. .. .. ..$ term_info:Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 21 obs. of 4 variables:
.. .. .. .. ..$ variable: chr [1:21] "Country" "Salary" "YearsCodedJob" "OpenSource" ...
.. .. .. .. ..$ type : chr [1:21] "nominal" "numeric" "numeric" "numeric" ...
.. .. .. .. ..$ role : chr [1:21] "predictor" "predictor" "predictor" "predictor" ...
.. .. .. .. ..$ source : chr [1:21] "original" "original" "original" "original" ...
.. .. .. ..$ steps :List of 3
.. .. .. .. ..$ :List of 9
.. .. .. .. .. ..$ terms :List of 2
.. .. .. .. .. .. ..$ : language ~all_nominal()
.. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc8986e40d8>
.. .. .. .. .. .. ..$ : language ~-all_outcomes()
.. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc8986e40d8>
.. .. .. .. .. .. ..- attr(*, "class")= chr "quosures"
.. .. .. .. .. ..$ role : chr "predictor"
.. .. .. .. .. ..$ trained : logi FALSE
.. .. .. .. .. ..$ one_hot : logi FALSE
.. .. .. .. .. ..$ preserve: logi FALSE
.. .. .. .. .. ..$ naming :function (var, lvl, ordinal = FALSE, sep = "_")
.. .. .. .. .. ..$ levels : NULL
.. .. .. .. .. ..$ skip : logi FALSE
.. .. .. .. .. ..$ id : chr "dummy_46eru"
.. .. .. .. .. ..- attr(*, "class")= chr [1:2] "step_dummy" "step"
.. .. .. .. ..$ :List of 7
.. .. .. .. .. ..$ terms :List of 1
.. .. .. .. .. .. ..$ : language ~all_predictors()
.. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc8986fec00>
.. .. .. .. .. .. ..- attr(*, "class")= chr "quosures"
.. .. .. .. .. ..$ role : logi NA
.. .. .. .. .. ..$ trained : logi FALSE
.. .. .. .. .. ..$ max_steps: num 5
.. .. .. .. .. ..$ removals : NULL
.. .. .. .. .. ..$ skip : logi FALSE
.. .. .. .. .. ..$ id : chr "lincomp_n6IDq"
.. .. .. .. .. ..- attr(*, "class")= chr [1:2] "step_lincomb" "step"
.. .. .. .. ..$ :List of 11
.. .. .. .. .. ..$ terms :List of 1
.. .. .. .. .. .. ..$ : language ~Remote
.. .. .. .. .. .. .. ..- attr(*, ".Environment")=<environment: 0x7fc898737028>
.. .. .. .. .. .. ..- attr(*, "class")= chr "quosures"
.. .. .. .. .. ..$ under_ratio: num 0.991
.. .. .. .. .. ..$ ratio : logi NA
.. .. .. .. .. ..$ role : logi NA
.. .. .. .. .. ..$ trained : logi FALSE
.. .. .. .. .. ..$ column : NULL
.. .. .. .. .. ..$ target : logi NA
.. .. .. .. .. ..$ skip : logi TRUE
.. .. .. .. .. ..$ id : chr "downsample_a8xcH"
.. .. .. .. .. ..$ seed : int 24391
.. .. .. .. .. ..$ id : chr "downsample_a8xcH"
.. .. .. .. .. ..- attr(*, "class")= chr [1:2] "step_downsample" "step"
.. .. .. ..$ template :Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 225 obs. of 21 variables:
.. .. .. .. ..$ Country : Factor w/ 5 levels "Canada","Germany",..: 4 5 3 3 4 4 5 1 5 2 ...
.. .. .. .. ..$ Salary : num [1:225] 50000 150000 17619 12480 87500 ...
.. .. .. .. ..$ YearsCodedJob : int [1:225] 7 20 2 3 10 3 3 4 4 5 ...
.. .. .. .. ..$ OpenSource : num [1:225] 0 0 0 1 1 0 0 1 1 0 ...
.. .. .. .. ..$ Hobby : num [1:225] 0 0 1 1 1 1 1 1 1 1 ...
.. .. .. .. ..$ CompanySizeNumber : num [1:225] 1000 20 1 10 100 100 100 20 10 100 ...
.. .. .. .. ..$ CareerSatisfaction : int [1:225] 7 7 8 6 5 8 10 10 9 10 ...
.. .. .. .. ..$ Data_scientist : num [1:225] 0 0 0 0 0 1 0 0 0 0 ...
.. .. .. .. ..$ Database_administrator : num [1:225] 1 0 0 0 0 0 0 1 1 0 ...
.. .. .. .. ..$ Desktop_applications_developer : num [1:225] 1 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ Developer_with_stats_math_background: num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ DevOps : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ Embedded_developer : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ Graphic_designer : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ Graphics_programming : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ Machine_learning_specialist : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ Mobile_developer : num [1:225] 0 1 1 0 0 0 0 0 1 1 ...
.. .. .. .. ..$ Quality_assurance_engineer : num [1:225] 0 0 0 0 0 0 0 0 0 0 ...
.. .. .. .. ..$ Systems_administrator : num [1:225] 0 0 0 0 0 0 0 0 1 0 ...
.. .. .. .. ..$ Web_developer : num [1:225] 0 0 0 1 1 0 1 1 1 0 ...
.. .. .. .. ..$ Remote : Factor w/ 2 levels "Remote","Not remote": 2 1 1 2 2 2 2 2 2 2 ...
.. .. .. ..$ levels : NULL
.. .. .. ..$ retained : logi NA
.. .. .. ..- attr(*, "class")= chr "recipe"
.. .. ..$ blueprint:List of 8
.. .. .. ..$ mold :List of 2
.. .. .. .. ..$ clean :function (blueprint, data)
.. .. .. .. ..$ process:function (blueprint, data)
.. .. .. ..$ forge :List of 2
.. .. .. .. ..$ clean :function (blueprint, new_data, outcomes)
.. .. .. .. ..$ process:function (blueprint, predictors, outcomes, extras)
.. .. .. ..$ intercept : logi FALSE
.. .. .. ..$ allow_novel_levels: logi FALSE
.. .. .. ..$ ptypes : NULL
.. .. .. ..$ fresh : logi FALSE
.. .. .. ..$ recipe : NULL
.. .. .. ..$ extra_role_ptypes : NULL
.. .. .. ..- attr(*, "class")= chr [1:3] "default_recipe_blueprint" "recipe_blueprint" "hardhat_blueprint"
.. .. ..- attr(*, "class")= chr [1:3] "action_recipe" "action_pre" "action"
..$ mold : NULL
..- attr(*, "class")= chr [1:2] "stage_pre" "stage"
$ fit :List of 2
..$ actions:List of 1
.. ..$ model:List of 2
.. .. ..$ spec :List of 5
.. .. .. ..$ args :List of 3
.. .. .. .. ..$ mtry : language ~9L
.. .. .. .. .. ..- attr(*, ".Environment")=<environment: R_EmptyEnv>
.. .. .. .. ..$ trees: language ~NULL
.. .. .. .. .. ..- attr(*, ".Environment")=<environment: R_EmptyEnv>
.. .. .. .. ..$ min_n: language ~19L
.. .. .. .. .. ..- attr(*, ".Environment")=<environment: R_EmptyEnv>
.. .. .. ..$ eng_args: Named list()
.. .. .. .. ..- attr(*, "class")= chr "quosures"
.. .. .. ..$ mode : chr "classification"
.. .. .. ..$ method : NULL
.. .. .. ..$ engine : chr "ranger"
.. .. .. ..- attr(*, "class")= chr [1:2] "rand_forest" "model_spec"
.. .. ..$ formula: NULL
.. .. ..- attr(*, "class")= chr [1:3] "action_model" "action_fit" "action"
..$ fit : NULL
..- attr(*, "class")= chr [1:2] "stage_fit" "stage"
$ post :List of 1
..$ actions: list()
..- attr(*, "class")= chr [1:2] "stage_post" "stage"
$ trained: logi FALSE
- attr(*, "class")= chr "workflow"
Great reprex, even for someone like me who hasn't done this, so I can ask a naive question.
Is the model that the fit_split() function's model argument expects the one in wf_final here?
wf_final$fit$actions[1][[1]][1]
$spec
Random Forest Model Specification (classification)
Main Arguments:
mtry = 3
min_n = 13
Computational engine: ranger
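For readers puzzling over the indexing chain above, here is a minimal base R sketch of what each bracket does. It uses a hypothetical plain list (`mock_wf`) shaped like the `$fit` part of the str() output, not a real workflow object, so the names and values are illustrative assumptions only.

```r
# Hypothetical stand-in shaped like the `$fit$actions` part of str(wf_final).
# A real workflow holds a parsnip model spec where this sketch uses a string.
mock_wf <- list(
  fit = list(
    actions = list(
      model = list(spec = "rand_forest spec", formula = NULL)
    )
  )
)

# [1]   keeps a list of length one (still wrapped, named "model")
# [[1]] unwraps it to the model action itself
# [1]   keeps a list of length one again (named "spec")
by_position <- mock_wf$fit$actions[1][[1]][1]

# The same sub-list, addressed by name instead of position.
by_name <- mock_wf$fit$actions$model["spec"]

# The bare element, without the enclosing list.
bare_spec <- mock_wf$fit$actions$model$spec

print(identical(by_position, by_name))
print(names(by_position))
```

Because the last step uses single brackets, the result is still a one-element list, which is why the printed output starts with a `$spec` header.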
Thanks. Credit for the reprex goes to @apreshill (here) and @Max (here).
Looks like it. That's the model I defined here with tune() placeholders:
tune_spec <-
  rand_forest(mtry = tune(),
              min_n = tune()) %>%
  set_engine("ranger") %>%
  set_mode("classification")
which becomes part of the workflow in this step...
tuneboth_wf <-
  workflow() %>%
  add_recipe(tune_rec) %>%
  add_model(tune_spec)
Again, I'm either playing or being dumb: how does the function find the model among all the objects in wf_final?
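One likely reading of the error, sketched in base R: fit_split() as defined in the first post takes `formula` and `model` as separate arguments and never looks inside a workflow, so piping wf_final into it fills `formula` and leaves `model` unmatched. The stand-in below (`fit_split_sketch`, `fake_wf`) is hypothetical and fits nothing; it only mirrors the argument order to show how R matches the arguments. The thread itself does not confirm this as the cause.

```r
# Hypothetical stand-in with the same argument order as fit_split().
# It only reports what each argument received; it does not fit anything.
fit_split_sketch <- function(formula, model, split, ...) {
  list(
    formula_class = class(formula)[1],
    model_class   = class(model)[1]  # forces `model`; errors if never supplied
  )
}

# Hypothetical placeholder for a finalized workflow (not a real workflow).
fake_wf <- structure(list(), class = "workflow")

# `wf_final %>% fit_split(split = so_split)` is the same call as
# `fit_split(wf_final, split = so_split)`: the piped object lands in the
# first argument, `formula`, and nothing is matched to `model`.
err <- tryCatch(
  fit_split_sketch(fake_wf, split = "so_split"),
  error = function(e) conditionMessage(e)
)
print(err)  # the message names the unmatched `model` argument

# Supplying a formula and a model separately matches both arguments.
ok <- fit_split_sketch(y ~ x, model = "a model spec", split = "so_split")
print(ok$formula_class)
```

If this reading is right, a workflow that already bundles a recipe and a model has nothing to hand to `model`, which is consistent with the suggestion below to call tune::last_fit() on the workflow directly.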
Hi Eric,
Did you try just straight up using tune::last_fit() here?
Hi, @apreshill. Yes, I should have mentioned that last_fit() does work in this case. But in my actual use case it did not, which I kind of expected based on your teaching notes. So I fell back to trying fit_split(), encountered the error described here, and realized the reprex also threw the same error. Have you come across it before?
Maybe the better course of action is for me to figure out why last_fit() did not work rather than troubleshoot fit_split()?
Yes, I would do that for sure. The warnings will be more useful (at least, one can hope!)
OK, sounds good. Though I will miss the rhyme of fit_split().
Closed by system on March 24, 2020, 12:33pm. This topic was automatically closed 7 days after the last reply. New replies are no longer allowed.