Trying to use an ML model as an objective function in R

I'm really hoping somebody can help with this issue. I am training a machine learning regression model (GBM), wrapping the model's prediction as a function, and passing that function to an optimisation routine as the objective to maximise. I want to identify the values of the variables that will maximise the target.

I have tried three approaches to building the ML model: caret, mlr and a standalone gbm model. Each approach gives me a different error message, all of which I have tested, and in each case the error message appears to be incorrect!

gbm attempt error: Error in eval(predvars, data, env) : numeric 'envir' arg not of length one.

But when I run length(as.numeric(train.predict[1])) it returns 1

mlr attempt error: Error in predict.WrappedModel(mod, newdata = TestOpt2) : Assertion on 'newdata' failed: Must be of type 'data.frame', not 'double'.

But when I run str(TestOpt2) it is a dataframe

Caret error: Error in eval(predvars, data, env) : object 'Variable1' not found

But when I compare the names of the data and the new data with identical(names(TestOpt2), names(All)), they are identical

You can cycle through these three options within the function Whatthefunc below.

I have also tried two different approaches to the optimisation: GA and L-BFGS-B. Each ML attempt yields identical error messages under both optimisers, which rules out the optimisation as the cause. There also doesn't seem to be any issue with my data.

Any help would be greatly appreciated!

Here is my code:

library(caret)
library(gbm)
library(janitor)
library(stats)
library(GA)
library(mlr)

#Data
All <- read.table(header=TRUE, text="
Variable1   Variable2   Variable3   Variable4   Variable5   Variable6   Variable7   Variable8   Variable9   Variable10  Variable11  Variable12  Variable13  Variable14  Variable15  Variable16  Variable17  Variable18  Variable19  Variable20  target3
0.579683609 119.85  24.88748088 33.0005098  26.4253 137.31  137.31  -79.09  0.85    13.2164 137.31  137.31  0.85    0.324080855 0   0.577120258 60.9    60.9    0.577120258 3.2433  0.987754707
0.576727679 113.15  24.31826592 32.25783496 26.4142 140.524 140.524 71.95   0.94    22.6956 140.524 140.524 0.94    0.441132275 0   0.57708545  61.5    61.5    0.57708545  0.0775  0.989919769
0.577208332 113 24.33705575 32.30735243 23.8317 140.321 140.321 71.15   0.9 7.9864  140.321 140.321 0.9 0.39304627  0   0.57685004  61.15   61.15   0.57685004  1.2631  0.987590272
0.577576199 108 24.49857658 32.50055574 28.0172 139.453 139.453 71.35   0.92    14.53   139.453 139.453 0.92    0.408846073 0   0.577576199 61.9    61.9    0.577576199 0.5694  0.988920063
0.577929838 115.25  24.17832737 32.05690516 35.2397 141.358 141.358 71.15   0.9 7.3456  141.358 141.358 0.9 0.386234956 0   0.577218539 62.15   62.15   0.577218539 0.3667  0.988492239
0.576087888 113.6   24.2771692  32.18319765 25.8444 140.766 140.766 71.75   0.93    13.3131 140.766 140.766 0.93    0.428226521 0   0.576445041 62.55   62.55   0.576445041 0.6256  0.988558816
0.577389864 106 24.69213425 32.75662706 21.0117 138.372 138.372 72.95   0.93    7.5042  138.372 138.372 0.93    0.429499571 0   0.577026497 62.55   62.55   0.577026497 0.5569  0.988496447
0.575851043 115.45  24.03782559 31.86960866 27.3122 142.126 142.126 71.8    0.89    20.9964 142.126 142.126 0.89    0.375293585 0   0.57726591  62.5    62.5    0.57726591  0.2094  0.988383885
0.576019403 112.15  24.52640571 32.55041024 27.385  139.307 139.307 73.55   0.89    21.5844 139.307 139.307 0.89    0.375711904 0   0.57746306  62.1    62.15   0.57746306  0.2781  0.987028649
0.57649588  108.4   24.36288483 32.31185861 21.8114 140.32  140.32  71.05   0.93    5.3328  140.32  140.32  0.93    0.420996059 0   0.577570763 62.25   61.95   0.577570763 1.6214  0.989307139
0.57500685  108 24.4966443  32.47934951 36.7106 139.464 139.464 73.2    0.92    17.3794 139.464 139.464 0.92    0.413139718 0   0.577530391 62.65   62.65   0.577530391 0.3953  0.986261578
0.567781627 151.9   24.41461532 32.38556766 29.6844 139.908 139.908 73.55   0.83    9.5711  139.908 139.908 0.83    0.276703704 0   0.568859693 62.3    62.3    0.568859693 0.0992  0.970956127
0.576629102 113.9   24.15090337 32.02984125 23.3719 141.415 141.415 69.4    0.9 5.2 141.415 141.415 0.9 0.386433929 0   0.577340112 61.85   61.85   0.577340112 2.5317  0.994880167
0.575800096 110.55  24.09816885 31.96126249 23.2731 141.878 141.878 72.05   0.94    7.7528  141.878 141.878 0.94    0.437962412 0   0.57721745  61.95   61.95   0.57721745  0.3333  0.988927823
0.575890709 132.65  22.71560496 30.10521237 22.6253 161.673 161.673 59.15   0.86    4.9164  161.673 161.673 0.86    0.34153475  4.35    0.576515658 72.5    72.5    0.576515658 0.095   0.984784422
0.575351305 129.35  24.36934046 32.32798574 27.4639 140.25  140.25  73.1    0.82    9.6103  140.25  140.25  0.82    0.275308288 0   0.577143677 61.7    61.7    0.577143677 0.2069  0.987748092
0.579825484 107.85  24.14594847 32.01085497 26.385  142.41542   141.502 -82.39  0.94    6.8306  141.502 141.502 0.94    0.438421965 0   0.576983202 61.7    61.7    0.576983202 1.3839  0.990199171
0.577772778 125.65  22.49128589 29.82829069 20.7594 163.241 163.241 61.1    0.91    8.6214  163.241 163.241 0.91    0.407784465 4.2 0.577147821 72  72  0.577147821 0.1531  0.976647918
0.575124778 109 24.38346004 32.33536351 24.8222 140.218 140.218 72.05   0.94    6.4353  140.218 140.218 0.94    0.436363376 0   0.577276117 62.25   62.25   0.577276117 0.1172  0.989513477
0.577344578 108.85  24.65573487 32.69436185 22.2528 138.556 138.556 73.35   0.92    16.1    138.556 138.556 0.92    0.412596345 0   0.576981696 61.5    61.5    0.576981696 0.0614  0.9873569
0.567566544 125.5   22.64769026 30.03528773 39.7897 162.096 162.096 -89.94  0.9 9.5936  162.096 162.096 0.9 0.391566645 14  0.578041977 72.2    72.2    0.578041977 0.0811  0.975553976
0.579272909 119.35  22.94560366 30.45994974 24.8756 159.974 159.974 51.9622 0.94    7.9658  159.974 159.974 0.94    0.404247596 0   0.577246803 71.15   71.15   0.577246803 0.1319  0.989694175
0.575181912 112.65  24.54710731 32.55725383 28.7628 139.161 139.161 72.8    0.88    6.2003  139.161 139.161 0.88    0.369243665 0   0.576988388 61.4    61.4    0.576988388 1.6692  0.989271033
0.688430032 110.35  29.47957925 90.45190574 24.2514 139.075 139.075 72  0.9 4.9797  139.075 139.075 0.9 0.472997038 0   0.69320778  61.85   62.062  0.69320778  1.0011  0.988330148
0.578171976 128.1   22.78654642 30.2128267  27.8903 184.272 184.272 59.75   1.03    14.5669 184.272 184.272 1.03    0.390344215 1.75    0.576921875 72.05   72.05   0.576921875 1.0964  0.987736085
0.576523356 109.2   24.78297132 32.8715959  26.3419 137.885 137.885 72.4    0.92    10.7547 137.885 137.885 0.92    0.409145608 0   0.576888014 61.8    61.8    0.576888014 0.1108  0.988333695
0.576707604 106.25  24.86655355 32.99422521 28.9808 137.321 137.321 73.05   0.95    11.1872 137.321 137.321 0.95    0.443424069 0   0.577439931 61.15   61.15   0.577439931 2.9386  0.989248037
0.577991293 119.25  22.82035988 30.24592405 26.3369 160.944 160.944 60.1    0.9 14.3708 160.944 160.944 0.9 0.395336036 1.5 0.577052995 71.95   71.95   0.577052995 0.1508  0.988499666
0.577801414 123.716 22.82977797 30.27862429 28.7808 160.79  160.79  60.95   0.93    4.2119  160.79  160.79  0.93    0.420759557 0.95    0.577175749 71.8    71.8    0.577175749 0.1558  0.990938575
0.560632614 126.0295    22.16052103 29.41094807 29.9239 165.673 165.673 58.95   0.96    16.3506 165.673 165.673 0.96    0.414182883 0   0.560632614 73.35   73.35   0.560632614 0.0967  0.96969688
0.577082392 123.1725    22.93355652 30.42984023 25.6892 160.106 160.106 61.2    0.93    13.1706 160.106 160.106 0.93    0.427712507 0   0.577082392 70.9    70.9    0.577082392 0.0997  0.988728365
")

#Newdata to predict
TestOpt2 <- as.data.frame(All[13,])


#Caret attempt

gbmGrid <-  expand.grid(interaction.depth = c(1, 2,3,4, 5, 9), 
                        n.trees = (1:30)*50, 
                        shrinkage = 0.1,
                        n.minobsinnode = 10)

metric <- "MAE"
trainControl <- trainControl(method="cv", number=3)

set.seed(99)
gbm.caret <- caret::train(target3 ~ .,
                          data=All,
                          distribution="gaussian",
                          method="gbm",
                          trControl=trainControl,
                          verbose=FALSE,
                          tuneGrid=gbmGrid,
                          bag.fraction = 6,
                          metric=metric) 


caret.predict <- predict(gbm.caret, newdata=TestOpt2)
caret.predict


#mlr attempt
#Make regression task
regr.task = makeRegrTask(data = All, target = "target3")

#Learner
learner <- makeLearner(
  "regr.gbm",par.vals = list(n.trees = 500, interaction.depth = 3, bag.fraction = 1), predict.type = "response")

#Model
mod <- train(learner, regr.task)

task.pred <- predict(mod, newdata = TestOpt2)
Modobj <- task.pred$data$response

#GBM attempt

gbm.gbm <- gbm(target3 ~ .
               , data=All
               , distribution="gaussian"
               , n.trees=50
               , interaction.depth=3
               , n.minobsinnode=10
               , shrinkage=0.1
               , bag.fraction=6
               , cv.folds=5
               , verbose=FALSE
)


train.predict <- predict.gbm(object=gbm.gbm, newdata=TestOpt2)
train.predict[1]

#Optimisation Objective function
Whatthefunc <- function(TestOpt2){
  
   # train.predict <- predict.gbm(object=gbm.gbm, newdata=TestOpt2,n.trees = 10)#GBM attempt
   #  return(as.numeric(train.predict[1]))
  
    # task.pred <-  predict(mod, newdata = TestOpt2)#mlr attempt
    #  return(task.pred$data$response)
  
   caret.predict <- stats::predict(gbm.caret, newdata=TestOpt2)#Caret attempt
   return(caret.predict)
}




#Max Constraints
Maxcon <-  sapply(All, max, na.rm = T)

#Min Constraints
Mincon <-  sapply(All, min, na.rm = T)

#Starting point
Mid1 <- sapply(All, mean, na.rm = T)

#Run Genetic algorithm
GA <- ga("real-valued", fitness = Whatthefunc, 
         lower = Mincon, upper = Maxcon,
         selection = GA:::gareal_lsSelection_R,
         maxiter = 1000, 
         popSize = 100,
         run = 10,
         pmutation = 0.1,
         elitism = 5, 
         pcrossover = 0.7,
         keepBest = TRUE,
         seed = 123)


#Run general purpose Optimisation

Opt <- optim(Maxcon, 
             Whatthefunc, 
             method = c( "L-BFGS-B"),
             lower = Mincon,
             upper = Maxcon,
             control=list(fnscale=-1)#Maximisation
)


#Caret error test
#Check are names the same in data and newdata
identical(names(TestOpt2),names(All))

#mlr error test
#Check new data is dataframe
str(TestOpt2)

#gbm error test
#Check length of prediction
length(as.numeric(train.predict[1]))

The purpose of the test set is to be fully held out from model tuning, so I would take a step back and question whether what you're trying to do makes sense. :slight_smile: If I understand the example correctly, you're trying to optimize the fit on the test set? Traditionally, that is the purpose of the validation set, not the test set, and train should already be accomplishing this for you. The test set should only be used to assess the quality of the fit, and should not be used in any sort of fitting or optimization.

I agree. You really should use a validation set or more extensive resampling to avoid doing the wrong thing.
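To make that concrete, here is a minimal sketch of a train/validation/test workflow with caret (using simulated stand-in data in place of the question's All, and an arbitrary 80/20 split): the cross-validation inside train() plays the role of the validation step, and the held-out rows are scored only once at the end.

```r
library(caret)

# Simulated stand-in for the All data frame in the question
set.seed(99)
dat <- data.frame(matrix(rnorm(2000), ncol = 20))
names(dat) <- paste0("Variable", 1:20)
dat$target3 <- rowMeans(dat[, 1:5]) + rnorm(100, sd = 0.1)

# Hold out a test set that is never touched during tuning
train_idx <- createDataPartition(dat$target3, p = 0.8, list = FALSE)
training  <- dat[train_idx, ]
testing   <- dat[-train_idx, ]

# trainControl(method = "cv") gives train() its internal validation,
# so tuning only ever sees the training rows
fit <- caret::train(target3 ~ ., data = training, method = "gbm",
                    trControl = trainControl(method = "cv", number = 3),
                    verbose = FALSE, metric = "MAE")

# The test set is used once, at the end, to report quality of fit
postResample(predict(fit, newdata = testing), testing$target3)
```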

A few code things:

  • For caret, you don't need to specify the distribution (unless you want something non-standard).
  • bag.fraction = 6 is not a fraction :smiley:

I would point you towards the tune and finetune packages in tidymodels. There are Bayesian optimization and simulated annealing search methods there. If you need to get started, take a look at the website and the book. mlr3 also has similar methods.
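As a rough illustration (a sketch only, with simulated stand-in data and xgboost as the boosting engine, since plain gbm is not a tidymodels engine), Bayesian optimization with tune might look like:

```r
library(tidymodels)

# Simulated stand-in data
set.seed(123)
dat <- data.frame(matrix(rnorm(2000), ncol = 20))
dat$target3 <- rowMeans(dat[, 1:5]) + rnorm(100, sd = 0.1)

folds <- vfold_cv(dat, v = 3)

# Boosted-tree model with three tunable hyperparameters
boost_spec <- boost_tree(trees = tune(), tree_depth = tune(),
                         learn_rate = tune()) |>
  set_engine("xgboost") |>
  set_mode("regression")

wf <- workflow() |>
  add_model(boost_spec) |>
  add_formula(target3 ~ .)

# Bayesian optimization of the hyperparameters, scored by resampling
res <- tune_bayes(wf, resamples = folds, initial = 5, iter = 10,
                  metrics = metric_set(mae))
show_best(res, metric = "mae")

# finetune::tune_sim_anneal() is the simulated-annealing counterpart
```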

If you want to do something custom, this GH repo has examples (that use resampling) for GA, SA, BO, PSO, and Nelder-Mead.
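In that custom spirit, one sketch (simulated stand-in data; the hyperparameter ranges are arbitrary) is to let GA search over gbm's hyperparameters, with the fitness being the resampled error rather than a prediction on a single held-out row:

```r
library(GA)
library(gbm)

# Simulated stand-in data
set.seed(1)
dat <- data.frame(matrix(rnorm(2000), ncol = 20))
dat$target3 <- rowMeans(dat[, 1:5]) + rnorm(100, sd = 0.1)

# Fitness = negative cross-validated deviance, so GA's maximisation
# minimises the resampled error (x = c(interaction.depth, shrinkage))
fitness <- function(x) {
  fit <- gbm(target3 ~ ., data = dat, distribution = "gaussian",
             n.trees = 100, interaction.depth = round(x[1]),
             shrinkage = x[2], cv.folds = 3, verbose = FALSE)
  -min(fit$cv.error)
}

GA_fit <- ga("real-valued", fitness = fitness,
             lower = c(1, 0.01), upper = c(5, 0.3),
             popSize = 10, maxiter = 5, run = 5, seed = 123)
summary(GA_fit)$solution
```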