Hi everyone,
Is it possible to use the `composition = "dgCMatrix"`
option when I pass a recipe to the `tune_grid()` function?
Currently, you can only set the `composition = "dgCMatrix"`
option in the `bake()`
function.
Thanks!
Here is a reprex:
# Packages ----------------------------------------------------------------
library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'xts':
#> method from
#> as.zoo.xts zoo
#> ── Attaching packages ────────────────────────────────────────────── tidymodels 0.0.3 ──
#> ✓ broom 0.5.2 ✓ recipes 0.1.7.9002
#> ✓ dials 0.0.4 ✓ rsample 0.0.5
#> ✓ infer 0.5.1 ✓ yardstick 0.0.4
#> ✓ parsnip 0.0.4.9000
#> ── Conflicts ───────────────────────────────────────────────── tidymodels_conflicts() ──
#> x scales::discard() masks purrr::discard()
#> x dplyr::filter() masks stats::filter()
#> x recipes::fixed() masks stringr::fixed()
#> x dplyr::lag() masks stats::lag()
#> x dials::margin() masks ggplot2::margin()
#> x yardstick::spec() masks readr::spec()
#> x recipes::step() masks stats::step()
#> x recipes::yj_trans() masks scales::yj_trans()
library(textrecipes)
library(workflows)
library(tune)
# Data --------------------------------------------------------------------
# Read the zipped training CSV and keep a random 10% sample of its rows.
train <- sample_frac(read_csv("train.csv.zip"), size = 0.1)
#> Parsed with column specification:
#> cols(
#> id = col_character(),
#> comment_text = col_character(),
#> toxic = col_double(),
#> severe_toxic = col_double(),
#> obscene = col_double(),
#> threat = col_double(),
#> insult = col_double(),
#> identity_hate = col_double()
#> )
# Recode the 0/1 `toxic` indicator as a two-level factor. tidymodels
# classification models expect a factor outcome, so build the factor
# explicitly instead of leaving `toxic` as a character column. The level
# order ("no", "yes") is the same as factor()'s default alphabetical order.
train <- train %>%
  mutate(
    toxic = factor(if_else(toxic == 1, "yes", "no"), levels = c("no", "yes"))
  )
# Recipe -----------------------------------------------------------------
# Preprocessing for the `comment_text` predictor, added one step at a time:
# tokenize, drop stop words, filter tokens (max_tokens = 1000), then tf-idf.
rec <- recipe(toxic ~ comment_text, data = train)
rec <- step_tokenize(rec, comment_text)
rec <- step_stopwords(rec, comment_text)
rec <- step_tokenfilter(rec, comment_text, max_tokens = 1000)
rec <- step_tfidf(rec, comment_text)
# Model ------------------------------------------------------------------
# Logistic regression specification using the glmnet engine. Both `penalty`
# and `mixture` are marked with tune() so they are filled in during tuning.
model <- set_engine(
  logistic_reg(mode = "classification", penalty = tune(), mixture = tune()),
  engine = "glmnet"
)
# Workflow ---------------------------------------------------------------
# Bundle the recipe and the model specification into a single workflow.
toxic_wflow <- workflow()
toxic_wflow <- add_recipe(toxic_wflow, rec)
toxic_wflow <- add_model(toxic_wflow, model)
# Grid --------------------------------------------------------------------
# Regular grid over `mixture` and `penalty`, with two levels per parameter.
param_grid <- grid_regular(mixture(), penalty(), levels = 2)
# CV ----------------------------------------------------------------------
# 10-fold cross-validation on the outcome and text columns only, stratified
# by `toxic`.
cv_rs <- vfold_cv(
  select(train, toxic, comment_text),
  v = 10,
  strata = toxic
)
# Tune --------------------------------------------------------------------
# Evaluate each row of `param_grid` on the resamples in `cv_rs`, scoring the
# candidates with ROC AUC. The original call ended with a stray trailing
# comma after `grid = param_grid`, which passes an empty argument; it is
# removed here.
grid_search <- tune_grid(
  toxic_wflow,
  resamples = cv_rs,
  metrics = metric_set(roc_auc),
  grid = param_grid
)

# Show the single best candidate. The metric is named explicitly so the call
# does not depend on show_best()'s default metric selection.
show_best(grid_search, metric = "roc_auc", n = 1)
#> # A tibble: 1 x 7
#> penalty mixture .metric .estimator mean n std_err
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl>
#> 1 1 0 roc_auc binary 0.893 10 0.00423
Created on 2019-12-12 by the reprex package (v0.3.0)