Precision-recall curve in cross-validation

Hello,

According to the documentation of the caret package, the following chunks compute the area under the ROC curve (AUC) in the context of cross-validation:

fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10,
                           ## Estimate class probabilities
                           classProbs = TRUE,
                           ## Evaluate performance using 
                           ## the following function
                           summaryFunction = twoClassSummary)

set.seed(825)
gbmFit3 <- train(Class ~ ., data = training, 
                 method = "gbm", 
                 trControl = fitControl, 
                 verbose = FALSE, 
                 tuneGrid = gbmGrid,
                 ## Specify which metric to optimize
                 metric = "ROC")

But I need to use the precision-recall curve instead, which is a more informative measure of classification performance when the classes are imbalanced. Could someone tell me whether the chunks below are the right way to do it?
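
For reference, a precision-recall curve can be sketched in base R: sort the cases by predicted probability, then compute precision (TP / (TP + FP)) and recall (TP / number of positives) at each cut-off. This is a minimal illustration with made-up scores and labels. The helper names pr_curve() and pr_auc_trapezoid() are hypothetical, ties in the scores are not handled, and the trapezoidal area is only a simple approximation, so it need not match the value that MLmetrics::PRAUC (which prSummary() relies on) reports.

```r
# Hypothetical helper: precision and recall at every cut-off.
# labels: 1 = positive (event) class, 0 = negative class; scores assumed distinct.
pr_curve <- function(scores, labels) {
  ord <- order(scores, decreasing = TRUE)
  labels <- labels[ord]
  tp <- cumsum(labels == 1)  # true positives when predicting "positive" down to this case
  fp <- cumsum(labels == 0)  # false positives at the same cut-off
  data.frame(threshold = scores[ord],
             precision = tp / (tp + fp),
             recall    = tp / sum(labels == 1))
}

# Hypothetical helper: area under the PR curve by the trapezoidal rule,
# starting the curve at recall = 0 with the first observed precision.
pr_auc_trapezoid <- function(curve) {
  r <- c(0, curve$recall)
  p <- c(curve$precision[1], curve$precision)
  sum(diff(r) * (head(p, -1) + tail(p, -1)) / 2)
}

# Made-up example data
scores <- c(0.9, 0.8, 0.7, 0.6, 0.5, 0.4)
labels <- c(1, 0, 1, 0, 0, 1)
curve <- pr_curve(scores, labels)
print(curve)
print(pr_auc_trapezoid(curve))  # 19/36 + 0.15, about 0.678
```

Unlike the ROC curve, the precision-recall curve ignores true negatives, which is why it reacts more strongly to how well the minority class is ranked.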

fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10,
                           ## Estimate class probabilities
                           classProbs = TRUE,
                           ## Evaluate performance using 
                           ## the following function
                           summaryFunction = prSummary)

set.seed(825)
gbmFit3 <- train(Class ~ ., data = training, 
                 method = "gbm", 
                 trControl = fitControl, 
                 verbose = FALSE, 
                 tuneGrid = gbmGrid,
                 ## Specify which metric to optimize
                 metric = "AUPRC")

Notice that only the summaryFunction and the metric arguments are changed.

I ask because I didn't see any mention of a metric = "AUPRC" argument in caret's documentation. Perhaps that argument is unnecessary once summaryFunction = prSummary is set in the trainControl() call above?

Thanks a lot!

prSummary() returns the area under the precision-recall curve in a column called AUC (see the examples in ?caret::prSummary), so use metric = "AUC".
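
A sketch to check which names each summary function returns, since metric has to match one of them. It calls twoClassSummary() and prSummary() directly on a small made-up data frame laid out the way caret passes hold-out predictions to a summary function (obs, pred, and one probability column per class), then reruns the question's setup with metric = "AUC". Assumptions: the caret and MLmetrics packages are installed, twoClassSim() data stands in for `training`, and method = "glm" replaces "gbm" only so the sketch runs quickly without a tuning grid. With the original gbm code, only the metric value needs to change.

```r
library(caret)  # prSummary() additionally requires the MLmetrics package

# Made-up hold-out predictions in the layout caret hands to a summary function
lev <- c("Class1", "Class2")
obs <- factor(rep(lev, times = c(10, 40)), levels = lev)
set.seed(1)
p1 <- ifelse(obs == "Class1", runif(50, 0.3, 1), runif(50, 0, 0.7))  # P(Class1)
toy <- data.frame(obs = obs,
                  pred = factor(ifelse(p1 > 0.5, "Class1", "Class2"), levels = lev),
                  Class1 = p1,
                  Class2 = 1 - p1)

# The names of these vectors are the valid values for train(metric = ...)
print(twoClassSummary(toy, lev = lev))  # ROC, Sens, Spec
print(prSummary(toy, lev = lev))        # AUC, Precision, Recall, F

# The question's setup with metric = "AUC" (glm swapped in for gbm as an assumption)
set.seed(825)
training <- twoClassSim(200)
fitControl <- trainControl(method = "repeatedcv",
                           number = 5,
                           repeats = 2,
                           ## Estimate class probabilities
                           classProbs = TRUE,
                           ## Report AUC (PR), Precision, Recall and F
                           summaryFunction = prSummary)
fit <- train(Class ~ ., data = training,
             method = "glm",
             trControl = fitControl,
             ## Must match a name returned by prSummary()
             metric = "AUC")
print(fit$results)
```

From my reading of prSummary()'s source, the first factor level is treated as the relevant (positive) class, so with imbalanced data it is worth making the minority class the first level of the outcome.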
