predict.rpart - type = raw vs. class

When I type ?predict.rpart, I see that type can be "vector", "prob", "class", or "matrix".

However, I've also seen type = "raw" used. Why can't I find it on the help page?

When should you use "raw" vs. "class"?

Can you give an example of code you have seen that passes "raw" to the type parameter of predict.rpart()? As far as I can tell, "raw" has never been an option:

  • The original S implementation had type = c("vector", "tree", "class")
  • Version 3.0-0 expanded that to type = c("vector", "tree", "class", "matrix")
  • Version 3.0-2 had type = c("vector", "matrix", "class", "probs")
  • But then Version 3.1-0 went to type = c("vector", "prob", "class", "matrix"), and thus it has remained in the 18 years since.

(This fun little spelunk was made possible by METACRAN's read-only CRAN mirror!)
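If you'd rather check this from an R session than from the package history, here is a minimal sketch. It assumes rpart is installed (it ships with R as a recommended package). It reads the default value of `type` from `predict.rpart()`'s formals and then checks that `match.arg()` against those choices rejects "raw":

```r
# Assumes the rpart package is installed (a recommended package shipped with R).
library(rpart)

# predict.rpart() is registered as an S3 method; reach it with ::: so we can
# read the default value of its `type` argument.
rpart_types <- eval(formals(rpart:::predict.rpart)$type)
print(rpart_types)

# match.arg() accepts any of these choices but fails on "raw".
raw_ok <- tryCatch({
  match.arg("raw", rpart_types)
  TRUE
}, error = function(e) FALSE)
print(raw_ok)
```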

Grid <- expand.grid(cp=seq(0, 0.2, 0.001))
fitControl <- trainControl(method = "cv", number = 6)
tree3 <- train(factor(TERM_FLAG)~.,
                data=tlFLAG.t,
                method='rpart',
                trControl = fitControl,
                metric = "Accuracy",
                tuneGrid = Grid,
                na.action = na.omit,
                parms=list(split='gini'))  # rpart's split names are lowercase: 'gini' or 'information'

plot(tree3)
rpart.plot(tree3$finalModel, extra=4)

Training set

pred3 <- predict(tree3, type = "raw")  # class labels; confusionMatrix() needs these, not probabilities
confusionMatrix(pred3, factor(tlFLAG.t$TERM_FLAG))

Validation set

pred.v3 <- predict(tree3, type = "raw", newdata = tlFLAG.v)
confusionMatrix(pred.v3, factor(tlFLAG.v$TERM_FLAG))

What library() statements go along with that code?

library(rpart.plot)
library(rpart)
library(caret)

It's going to be impossible for me to completely settle this without a reproducible example, but here's what I'm guessing is happening in that code:

tree3 <- train(factor(TERM_FLAG)~.,
                data=tlFLAG.t,
                method='rpart',
                trControl = fitControl,
                metric = "Accuracy",
                tuneGrid = Grid,
                na.action = na.omit,
                parms=list(split='gini'))

It looks like the tree3 object is being created by caret::train(). caret::train() returns a list with the class train.

As an S3 generic function, predict() uses the class of the object passed to it to determine which specific function should be called (you can use class() to see an object's class(es)). Therefore, when tree3 is later passed to the generic predict() function, the specific method that is running is predict.train(), not predict.rpart().
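To see that dispatch rule in isolation, here is a minimal base-R sketch with made-up classes. `toy_train`, `toy_rpart`, and their predict methods are hypothetical stand-ins for caret's `train` class and rpart's `rpart` class, not code from either package:

```r
# Hypothetical objects whose classes mimic caret's "train" and rpart's "rpart".
toy_fit  <- structure(list(), class = "toy_train")
toy_tree <- structure(list(), class = "toy_rpart")

# Hypothetical S3 methods; the generic predict() picks one based on class().
predict.toy_train <- function(object, ...) "predict.toy_train() was called"
predict.toy_rpart <- function(object, ...) "predict.toy_rpart() was called"

class(toy_fit)
#> [1] "toy_train"
predict(toy_fit)
#> [1] "predict.toy_train() was called"
predict(toy_tree)
#> [1] "predict.toy_rpart() was called"
```

The same call, `predict(...)`, ends up in a different function depending only on the class of its first argument, which is why `predict(tree3, ...)` never reaches `predict.rpart()` directly.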

If you consult the predict.train() documentation, you'll see a brief mention of what type = "raw" means:

type
either "raw" or "prob", for the number/class predictions or class probabilities, respectively. Class probabilities are not available for all classification models

There's more detail in the full caret documentation, but confusingly, I think there's a mistake in the docs! The relevant section of the documentation says:

Also, there are very few standard syntaxes for model predictions in R. For example, to get class probabilities, many predict methods have an argument called type that is used to specify whether the classes or probabilities should be generated. Different packages use different values of type, such as "prob", "posterior", "response", "probability" or "raw". In other cases, completely different syntax is used.

For predict.train, the type options are standardized to be "class" and "prob" (the underlying code matches these to the appropriate choices for each model.

But I'm pretty certain that last line ought to say that the type options are standardized to be "raw" and "prob". The following line of code has been in predict.train() since at least 2008:

    if(!(type %in% c("raw", "prob"))) stop("type must be either \"raw\" or \"prob\"")
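Here is a minimal sketch of what that guard does, wrapped in a hypothetical helper `check_train_type()` (not a caret function) so it can run on its own. "raw" and "prob" pass, while "class", the value the docs mention, is rejected:

```r
# Hypothetical helper reproducing the guard quoted above from predict.train().
check_train_type <- function(type) {
  if (!(type %in% c("raw", "prob"))) stop("type must be either \"raw\" or \"prob\"")
  invisible(type)
}

# TRUE if the guard accepts `type`, FALSE if it stops with an error.
is_accepted <- function(type) {
  tryCatch({
    check_train_type(type)
    TRUE
  }, error = function(e) FALSE)
}

# "raw" and "prob" are accepted; "class" is not.
sapply(c("raw", "prob", "class"), is_accepted)
```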

So what does type = "raw" mean in terms of rpart?

When you pass type = "raw" to caret's predict.train(), it uses its own logic to map that onto the underlying model in order to generate predictions in terms of raw numbers or classes. To see exactly what prediction function is being called, use getModelInfo():

library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2

rpart_info <- getModelInfo("rpart")

# This is what gets called for `type = "raw"`
rpart_info$rpart$predict
#> function(modelFit, newdata, submodels = NULL) {
#>                     if(!is.data.frame(newdata)) newdata <- as.data.frame(newdata)
#> 
#>                     pType <- if(modelFit$problemType == "Classification") "class" else "vector"
#>                     out  <- predict(modelFit, newdata, type=pType)
#> 
#>                     if(!is.null(submodels))
#>                     {
#>                       tmp <- vector(mode = "list", length = nrow(submodels) + 1)
#>                       tmp[[1]] <- out
#>                       for(j in seq(along = submodels$cp))
#>                       {
#>                         prunedFit <- rpart::prune.rpart(modelFit, cp = submodels$cp[j])
#>                         tmp[[j+1]]  <- predict(prunedFit, newdata, type=pType)
#>                       }
#>                       out <- tmp
#>                     }
#>                     out
#>                   }

# This is what gets called for `type = "prob"`
rpart_info$rpart$prob
#> function(modelFit, newdata, submodels = NULL) {
#>                     if(!is.data.frame(newdata)) newdata <- as.data.frame(newdata)
#>                     out <- predict(modelFit, newdata, type = "prob")
#> 
#>                     if(!is.null(submodels))
#>                     {
#>                       tmp <- vector(mode = "list", length = nrow(submodels) + 1)
#>                       tmp[[1]] <- out
#>                       for(j in seq(along = submodels$cp))
#>                       {
#>                         prunedFit <- rpart::prune.rpart(modelFit, cp = submodels$cp[j])
#>                         tmpProb <- predict(prunedFit, newdata, type = "prob")
#>                         tmp[[j+1]] <- as.data.frame(tmpProb[, modelFit$obsLevels, drop = FALSE])
#>                       }
#>                       out <- tmp
#>                     }
#>                     out
#>                   }

Created on 2018-11-20 by the reprex package (v0.2.1)

The important lines for the type = "raw" case:

pType <- if(modelFit$problemType == "Classification") "class" else "vector"
out  <- predict(modelFit, newdata, type=pType)

And for the type = "prob" case:

out <- predict(modelFit, newdata, type = "prob")

So passing type = "raw" to predict.train() gives you either type = "class" or type = "vector" in the underlying call to predict.rpart(). Passing type = "prob" to predict.train() gives you type = "prob" in the underlying call to predict.rpart().
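To see the two underlying calls side by side without caret, here is a minimal sketch that fits rpart directly on the kyphosis data bundled with rpart. That dataset is only an illustration, not the tlFLAG.t data from the question. For a classification tree, `type = "class"` is what caret's "raw" resolves to, and `type = "prob"` is what caret's "prob" passes through:

```r
# Assumes the rpart package is installed; kyphosis is a dataset shipped with it.
library(rpart)

# Fit a small classification tree.
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

# What caret's type = "raw" becomes for classification: a factor of class labels.
pred_class <- predict(fit, newdata = kyphosis, type = "class")

# What caret's type = "prob" becomes: a matrix with one column per class level.
pred_prob <- predict(fit, newdata = kyphosis, type = "prob")

str(pred_class)
head(pred_prob)

# Class labels can be cross-tabulated against the truth (as confusionMatrix() does);
# the probability matrix would first need a threshold to become labels.
table(predicted = pred_class, observed = kyphosis$Kyphosis)
```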
