Grid for bartMachine with caret?

After finishing DataCamp's introduction to caret, I'm trying to learn how to use bartMachine. This is the code I wrote based on chapter 5 of that course:

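# java.parameters is read only once, when the JVM starts (i.e. when bartMachine
# is loaded), so it has to be set before library(bartMachine)
options(java.parameters = "-Xmx30000m")
library(caret)
library(C50)         # provides the churn data
library(doParallel)
library(bartMachine)
set_bart_machine_num_cores(1)  # one Java core per model; caret parallelizes across folds
cl <- makePSOCKcluster(10)
registerDoParallel(cl)
data(churn)          # loads churnTrain and churnTest
set.seed(9782)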
# Create train/test indexes -----------------------------------------------
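# trainControl(index = ) expects the rows used for *training* in each resample;
# createFolds() returns the held-out rows unless returnTrain = TRUE
myFolds <- createFolds(churnTrain$churn, k = 5, returnTrain = TRUE)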
# My control --------------------------------------------------------------
myControl <- trainControl(
  summaryFunction = twoClassSummary,
  classProbs = TRUE,
  verboseIter = TRUE,
  savePredictions = TRUE,
  index = myFolds
)

# BART --------------------------------------------------------------------
bartGrid <- expand.grid(num_trees = c(10, 15, 20, 100), k = 2, alpha = 0.95, beta = 2, nu = 3)
model_bart <- train(churn ~ .,
                    churnTrain,
                    metric = "ROC",
                    method = "bartMachine",
                    trControl = myControl,
                    preProcess = c("center", "scale"),
                    tuneGrid = bartGrid,
                    # remaining arguments are passed through to bartMachine()
                    num_burn_in = 2000,
                    num_iterations_after_burn_in = 2000,
                    serialize = TRUE)

# stop cluster ------------------------------------------------------------
stopCluster(cl)

plot(model_bart)
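For scale: expand.grid() builds one candidate per combination of the values it is given, and caret fits every candidate on every resample and then refits the best one on the full training set. A quick base-R check of how many BART fits the grid above implies (n_candidates and n_fits are just illustrative names):

```r
# Same grid as in the question: only num_trees varies, so 4 candidates.
bartGrid <- expand.grid(num_trees = c(10, 15, 20, 100),
                        k = 2, alpha = 0.95, beta = 2, nu = 3)

n_candidates <- nrow(bartGrid)   # 4 rows
n_folds      <- 5                # createFolds(..., k = 5)

# One fit per candidate per fold, plus one final fit on all of churnTrain.
n_fits <- n_candidates * n_folds + 1
n_fits                           # 21
```

Each of those fits runs 2000 burn-in plus 2000 post-burn-in iterations. Note also that k, alpha, beta and nu each take a single value here (I believe these match bartMachine's defaults), so only num_trees is actually being tuned.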

Alas, the AUC is below 0.5, while with ranger I can get above 0.9. My guess is that my bartGrid is the problem. If that is the case, do you have any suggestions on how to improve my code to get better results?

Thanks!
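An AUC below 0.5 means the model ranks the two classes backwards more often than not, which often points to the probability column being read for the wrong class rather than to a bad tuning grid. I haven't confirmed that this is what happens with caret's bartMachine wrapper here; the base-R sketch below (auc_rank() is a hypothetical helper and the data are made up) only shows why swapping the class probabilities turns an AUC of A into 1 - A.

```r
# Rank-based AUC (Mann-Whitney form): the probability that a randomly chosen
# positive case scores higher than a randomly chosen negative one,
# with ties counted as one half.
auc_rank <- function(score, is_pos) {
  pos  <- score[is_pos]
  neg  <- score[!is_pos]
  wins <- outer(pos, neg, ">") + 0.5 * outer(pos, neg, "==")
  mean(wins)
}

# Toy data (made up): scores meant to be P(churn == "yes").
is_yes <- c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE)
p_yes  <- c(0.9, 0.8, 0.4, 0.3, 0.6, 0.1)

auc_rank(p_yes, is_yes)      # 8/9, about 0.889
auc_rank(1 - p_yes, is_yes)  # 1/9, about 0.111: same ranking, classes swapped
```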

I'm surprised that a Java-based model works with parallel processing. I'm also not sure why you are getting poor ROC values.

Try using random search instead of grid search via the search argument to trainControl.
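In caret, random search is switched on with search = "random" in trainControl(); the number of random candidates then comes from tuneLength in train() (default 3), and no tuneGrid is passed. A minimal sketch under those assumptions; ctrl_random and fit_bart_random() are made-up names, and method = "cv" stands in for the custom index folds:

```r
library(caret)

# Random search: caret samples `tuneLength` random settings of the model's
# tuning parameters instead of expanding a full grid.
ctrl_random <- trainControl(
  method = "cv",
  number = 5,
  search = "random",
  summaryFunction = twoClassSummary,
  classProbs = TRUE,
  savePredictions = TRUE
)

# Hypothetical helper: fit BART on n_candidates randomly drawn settings.
fit_bart_random <- function(data, ctrl, n_candidates = 5) {
  train(churn ~ ., data = data,
        method = "bartMachine",
        metric = "ROC",
        trControl = ctrl,
        tuneLength = n_candidates,  # number of random candidates
        serialize = TRUE)
}
```

Usage would be something like fit_bart_random(churnTrain, ctrl_random). A smaller tuneLength also means fewer BART fits overall.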


I tried this change, but I ran out of memory:

Iteration 200/4000
Exception in thread "pool-34-thread-1" java.lang.OutOfMemoryError: Java heap space
done building BART in 3.629 sec 

evaluating in sample data...Error in .jcall(bart_machine$java_bart_machine, "[[D", "getGibbsSamplesForPrediction",  : 
  java.lang.OutOfMemoryError: GC overhead limit exceeded
In addition: Warning message:
In nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo,  :
 
 Error in .jcall(bart_machine$java_bart_machine, "[[D", "getGibbsSamplesForPrediction",  : 
  java.lang.OutOfMemoryError: GC overhead limit exceeded
Timing stopped at: 20.77 0.227 4.081

This is the code that is crashing:

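# java.parameters is read only once, when the JVM starts (i.e. when bartMachine
# is loaded), so it has to be set before library(bartMachine)
options(java.parameters = "-Xmx30000m")
library(caret)
library(C50)         # provides the churn data
library(doParallel)
library(bartMachine)
set_bart_machine_num_cores(1)  # one Java core per model; caret parallelizes across folds
cl <- makePSOCKcluster(2)
registerDoParallel(cl)
data(churn)          # loads churnTrain and churnTest
set.seed(9782)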
# Create train/test indexes -----------------------------------------------
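# trainControl(index = ) expects the rows used for *training* in each resample;
# createFolds() returns the held-out rows unless returnTrain = TRUE
myFolds <- createFolds(churnTrain$churn, k = 5, returnTrain = TRUE)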
# My control --------------------------------------------------------------
myControl <- trainControl(
  summaryFunction = twoClassSummary,
  classProbs = TRUE,
  verboseIter = TRUE,
  savePredictions = TRUE,
  index = myFolds,
  search = "random"
)

# BART --------------------------------------------------------------------
# bartGrid <- expand.grid(num_trees = c(10, 15, 20, 100), k = 2, alpha = 0.95, beta = 2, nu = 3)
model_bart <- train(churn ~ .,
                    churnTrain,
                    metric = "ROC",
                    method = "bartMachine",
                    trControl = myControl,
                    preProcess = c("center", "scale"),
                    # tuneGrid = bartGrid,
                    # remaining arguments are passed through to bartMachine()
                    num_burn_in = 2000,
                    num_iterations_after_burn_in = 2000,
                    serialize = TRUE)

# stop cluster ------------------------------------------------------------
stopCluster(cl)

plot(model_bart)

Any suggestions?

Have you tried giving Java more memory via java.parameters? Your machine may have the RAM, but Java isn't allowed to use it.

If that's not the issue, run sequentially. Using M parallel workers means roughly M + 1 times the original process's memory usage in total: one copy for the main session plus one per worker.
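The M + 1 estimate gives a rough way to size the Java heap: if the main session and every worker may each start a JVM, the RAM you can give Java has to be shared among all of them. A base-R sketch of that arithmetic; xmx_per_process() is a hypothetical helper, and keeping some RAM in reserve for the OS and R itself is an assumption, not something from this thread:

```r
# Hypothetical helper: split total RAM across the main R process and
# n_workers parallel workers (the "M + 1" estimate), after setting aside
# reserve_mb, and return the matching -Xmx flag for each JVM.
xmx_per_process <- function(total_ram_mb, n_workers, reserve_mb = 0) {
  stopifnot(total_ram_mb > reserve_mb, n_workers >= 0)
  per_process_mb <- floor((total_ram_mb - reserve_mb) / (n_workers + 1))
  sprintf("-Xmx%.0fm", per_process_mb)  # %.0f avoids "1e+05"-style output
}

# Example: a 32 GB machine, 2 workers, 4 GB kept back for everything else.
xmx_per_process(32768, n_workers = 2, reserve_mb = 4096)  # "-Xmx9557m"
```

To run sequentially instead, skip makePSOCKcluster()/registerDoParallel() or call foreach::registerDoSEQ(); setting allowParallel = FALSE in trainControl() also stops train() from using a registered parallel backend.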

I'm currently doing this:

options(java.parameters = "-Xmx30000m")

Is that the right way to give Java more memory? If so, I will need a computer with more RAM if I want to use this package.

I couldn't say. I'd ask the package maintainer.
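As far as I know, options(java.parameters = ...) is the usual way to raise the heap for rJava-based packages, but two details matter. The option is read only when the JVM starts, which happens when bartMachine is loaded, so it must be set before library(bartMachine). Also, PSOCK workers are separate R processes that do not inherit options() from the main session. The sketch below assumes rJava's .jcall() for checking the heap that actually took effect; -Xmx8g is an arbitrary example size, and the worker handling assumes caret loads bartMachine on each worker only when it first needs it.

```r
# Run in a fresh R session: once the JVM is running, java.parameters is ignored.
options(java.parameters = "-Xmx8g")   # 8g is only an example size
library(bartMachine)

# Ask the running JVM how large its heap may grow, in MB.
runtime     <- rJava::.jcall("java/lang/Runtime", "Ljava/lang/Runtime;", "getRuntime")
max_heap_mb <- rJava::.jcall(runtime, "J", "maxMemory") / 1024^2
print(max_heap_mb)

# PSOCK workers start with default options, so set java.parameters on each
# worker before bartMachine (and hence the JVM) is loaded there.
library(doParallel)
cl <- makePSOCKcluster(2)
invisible(parallel::clusterEvalQ(cl, options(java.parameters = "-Xmx8g")))
registerDoParallel(cl)

# ... call train() here ...

stopCluster(cl)
```

If the printed value does not roughly match the size you asked for, the option most likely did not take effect.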