I am trying to tune a TabNet model with `tune_grid()`, but I immediately run into trouble. Here's my toy example:
library(tidyverse)
library(tidymodels)
library(tabnet)
# Draw a random sample of 2000 rows to try the models on
set.seed(1234)
diamonds <- diamonds %>%
  sample_n(2000)
diamonds_split <- initial_split(diamonds, prop = 0.80, strata = "price")
diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)
folds <- vfold_cv(diamonds_train, v = 5, strata = "price")
metric <- metric_set(rmse, rsq, mae)
# TabNet model
tabnet_spec <- tabnet(epochs = tune("ep"), batch_size = tune("bs")) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("regression")
tabnet_rec <-
  recipe(price ~ ., data = diamonds_train) %>%
  step_normalize(all_numeric_predictors())
tabnet_wflow <-
  workflow() %>%
  add_model(tabnet_spec) %>%
  add_recipe(tabnet_rec)
tabnet_grid <- expand_grid(ep = c(5, 10), bs = c(64, 128))
tabnet_res <-
  tune_grid(
    tabnet_wflow,
    resamples = folds,
    metrics = metric,
    grid = tabnet_grid
  )
collect_metrics(tabnet_res)
Here are the errors I get:
There were issues with some computations A: x10
Warning message:
All models failed. Run `show_notes(.Last.tune.result)` for more information.
>
> collect_metrics(tabnet_res)
Error in `estimate_tune_results()`:
! All models failed. Run `show_notes(.Last.tune.result)` for more information.
Run `rlang::last_trace()` to see where the error occurred.
> show_notes(.Last.tune.result)
unique notes:
───────────────────────────────────────────────────────────
attempt to select less than one element in get1index <real>
If I fit the four combinations one by one, I don't get any errors. Am I missing something obvious?