Problems using tune_grid with tabnet

I am trying to tune tabnet using tune_grid(), but I immediately run into trouble. Here's my toy example:

library(tidyverse)
library(tidymodels)
library(tabnet)

# Draw a random sample of 2000 to try the models

set.seed(1234)

diamonds <- diamonds %>%
  sample_n(2000)

diamonds_split <- initial_split(diamonds, prop = 0.80, strata = "price")

diamonds_train <- training(diamonds_split)
diamonds_test <- testing(diamonds_split)

folds <- vfold_cv(diamonds_train, v = 5, strata = "price")
metric <- metric_set(rmse, rsq, mae)

# TabNet model

tabnet_spec <- tabnet(epochs = tune("ep"), batch_size = tune("bs")) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("regression")

tabnet_rec <-
  recipe(price ~ ., data = diamonds_train) %>%
  step_normalize(all_numeric_predictors()) 

tabnet_wflow <- 
  workflow() %>%
  add_model(tabnet_spec) %>%
  add_recipe(tabnet_rec)

tabnet_grid <- expand_grid(ep = c(5, 10), bs = c(64, 128))

tabnet_res <-
  tune_grid(
    tabnet_wflow,
    resamples = folds,
    metrics = metric,
    grid = tabnet_grid
  )
  
collect_metrics(tabnet_res)

Here is the error output from tune_grid():

There were issues with some computations   A: x10
Warning message:
All models failed. Run `show_notes(.Last.tune.result)` for more information. 
>   
> collect_metrics(tabnet_res)
Error in `estimate_tune_results()`:
! All models failed. Run `show_notes(.Last.tune.result)` for more information.
Run `rlang::last_trace()` to see where the error occurred.
> show_notes(.Last.tune.result)
unique notes:
───────────────────────────────────────────────────────────
attempt to select less than one element in get1index <real>

If I fit the four combinations one by one, I don't get any errors. Am I missing something obvious?
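For reference, this is roughly what I mean by fitting one combination by hand (just a sketch; single_params and single_fit are placeholder names, and ep = 5, bs = 64 stands in for any of the four rows of tabnet_grid):

# Plug one parameter combination into the workflow and fit it directly
single_params <- tibble(ep = 5, bs = 64)

single_fit <-
  tabnet_wflow %>%
  finalize_workflow(single_params) %>%  # fills in tune("ep") and tune("bs")
  fit(data = diamonds_train)

Each of the four combinations fits this way without any error.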
