Compute confusion matrix using k-fold cross-validation in caret::train

Hi all,

I need help with the caret::train function. While experimenting in R, I created a new variable called "age" in the Auto data frame so that I can predict whether a car is "old" or "new", depending on whether its year is below or above the median of the variable "year". Now I just want to perform LDA using 10-fold CV. I understand from the trainControl function that with classProbs=TRUE the method will return class probabilities and the assigned class, but I can't find what I'm looking for. Ideally I want something similar to the "CV=TRUE" argument of the MASS::lda function, but using k-fold CV instead of LOOCV. Any ideas? Here's the code:

library(ISLR)
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
Auto=Auto

#create vector and add as new column to the Auto data frame
age= rep("new", 392)
median(Auto$year)
#> [1] 76
age[Auto$year < 76]= "old"
Auto=cbind(Auto, age = factor(age))  # explicit factor: R >= 4.0 no longer converts strings to factors here

set.seed(123)
train_control= trainControl(method = "cv", number = 10, classProbs = TRUE)

#train the model
lda_auto_10cv= train(age ~ mpg + cylinders + displacement + acceleration + weight + horsepower, data= Auto, method= "lda", trControl=train_control)
print(lda_auto_10cv)
#> Linear Discriminant Analysis 
#> 
#> 392 samples
#>   6 predictor
#>   2 classes: 'new', 'old' 
#> 
#> No pre-processing
#> Resampling: Cross-Validated (10 fold) 
#> Summary of sample sizes: 352, 353, 353, 353, 353, 353, ... 
#> Resampling results:
#> 
#>   Accuracy   Kappa    
#>   0.7448077  0.4909725

Created on 2019-10-16 by the reprex package (v0.3.0)

Thanks!

Using method = "cv" with number = 10 in trainControl() gives you 10-fold cross-validation. Is this what you are trying to do? You would use method = "LOOCV" to do leave-one-out.
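A minimal sketch of the two trainControl() set-ups being contrasted here (the object names ctrl_cv and ctrl_loo are just illustrative):

```r
library(caret)

# 10-fold cross-validation: the data are split into 10 folds, and each
# fold is held out once while the model is fit on the other 9
ctrl_cv <- trainControl(method = "cv", number = 10)

# Leave-one-out cross-validation: every single observation is held out once
ctrl_loo <- trainControl(method = "LOOCV")
```

Either object is then passed to train() through its trControl argument.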

What I mean is this: with 10-fold CV in caret::train, printing the model gives me the accuracy (the proportion of correctly classified observations), but no class predictions. With the MASS::lda function, it seems I can do LOOCV (not k-fold CV) by passing the "CV = TRUE" argument, and one of the outputs is then the predicted classes, which I can use to compute a confusion matrix. My question is whether caret::train with 10-fold CV can give me class predictions (and hence a confusion matrix), or whether there is another function whose output I could use to compute the confusion matrix with k-fold CV.
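For reference, the MASS::lda route described above looks roughly like this. It is a sketch that rebuilds the age variable with ifelse() (equivalent to the code in the question, since the median year is 76); the object name lda_loo is just illustrative. With CV = TRUE, lda() returns leave-one-out predictions rather than a fitted model:

```r
library(ISLR)
library(MASS)

Auto <- ISLR::Auto
# Same split as in the question: year below the median -> "old", else "new"
Auto$age <- factor(ifelse(Auto$year < median(Auto$year), "old", "new"))

# CV = TRUE: each row is predicted by a model fit on all the other rows
lda_loo <- lda(age ~ mpg + cylinders + displacement + acceleration + weight + horsepower,
               data = Auto, CV = TRUE)

# lda_loo$class holds the LOOCV predicted class for every row
table(Predicted = lda_loo$class, Observed = Auto$age)

# LOOCV accuracy
mean(lda_loo$class == Auto$age)
```

The limitation is that this only gives leave-one-out predictions; there is no argument here for k folds.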

I understand your question now. Here's some code, but see the note below.

library(ISLR)
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2

Auto = Auto

#create vector and add as new column to the Auto data frame
age = rep("new", 392)
median(Auto$year)
#> [1] 76

age[Auto$year < 76] = "old"
Auto = cbind(Auto, age = factor(age))  # explicit factor: R >= 4.0 no longer converts strings to factors here

set.seed(123)
train_control = trainControl(method = "cv",
                             number = 10,
                             classProbs = TRUE,
                             savePredictions = "final")

#train the model
lda_auto_10cv = train(
  age ~ mpg + cylinders + displacement + acceleration + weight + horsepower,
  data = Auto,
  method = "lda",
  trControl = train_control
)

# Get the average confusion matrix over resamples:
confusionMatrix(lda_auto_10cv)
#> Cross-Validated (10 fold) Confusion Matrix 
#> 
#> (entries are percentual average cell counts across resamples)
#>  
#>           Reference
#> Prediction  new  old
#>        new 38.5  9.9
#>        old 15.6 36.0
#>                             
#>  Accuracy (average) : 0.7449

# Get a confusion matrix by pooling the out-of-sample predictions
confusionMatrix(lda_auto_10cv$pred$pred, lda_auto_10cv$pred$obs)
#> Confusion Matrix and Statistics
#> 
#>           Reference
#> Prediction new old
#>        new 151  39
#>        old  61 141
#>                                           
#>                Accuracy : 0.7449          
#>                  95% CI : (0.6987, 0.7873)
#>     No Information Rate : 0.5408          
#>     P-Value [Acc > NIR] : < 2e-16         
#>                                           
#>                   Kappa : 0.4911          
#>                                           
#>  Mcnemar's Test P-Value : 0.03573         
#>                                           
#>             Sensitivity : 0.7123          
#>             Specificity : 0.7833          
#>          Pos Pred Value : 0.7947          
#>          Neg Pred Value : 0.6980          
#>              Prevalence : 0.5408          
#>          Detection Rate : 0.3852          
#>    Detection Prevalence : 0.4847          
#>       Balanced Accuracy : 0.7478          
#>                                           
#>        'Positive' Class : new             
#> 

Created on 2019-10-18 by the reprex package (v0.3.0)

Note that the pooled confusion matrix will not, in general, give exactly the same numbers as the resampled accuracy statistics. The former is computed by pooling all out-of-fold predictions into a single table, while the resampling results average the 10 per-fold accuracy values.
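To make that difference concrete, here is a minimal sketch that recomputes both numbers from the saved out-of-fold predictions. It repeats the setup from the answer (building age with ifelse() and factor(), which is equivalent to the code above), and the names ctrl, fit, oof and fold_acc are just illustrative:

```r
library(ISLR)
library(caret)

Auto <- ISLR::Auto
Auto$age <- factor(ifelse(Auto$year < median(Auto$year), "old", "new"))

set.seed(123)
ctrl <- trainControl(method = "cv", number = 10,
                     classProbs = TRUE, savePredictions = "final")
fit <- train(age ~ mpg + cylinders + displacement + acceleration + weight + horsepower,
             data = Auto, method = "lda", trControl = ctrl)

# One row per held-out prediction; the Resample column names the fold
oof <- fit$pred

# Averaged: accuracy within each fold, then the mean of the 10 values
fold_acc <- sapply(split(oof, oof$Resample),
                   function(d) mean(d$pred == d$obs))
mean(fold_acc)

# Pooled: a single accuracy over all out-of-fold predictions at once
mean(oof$pred == oof$obs)
```

Because the folds have slightly different sizes (39 or 40 rows here), the two values usually differ only in the later decimals; mean(fold_acc) should agree with the Accuracy that print(fit) reports.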

Thanks, that's exactly what I wanted!
