When I type in ?predict.rpart, I see type = "vector", "prob", "class", and "matrix".
However, I've also seen type = "raw" used. Why can't I see this in the help file?
When do you use "raw" vs. "class"?
Can you give an example of code you have seen that passes "raw" to the type parameter of predict.rpart()? As far as I can tell, "raw" has never been an option. Over the package's history, the type argument has looked like this:
type = c("vector", "tree", "class")
type = c("vector", "tree", "class", "matrix")
type = c("vector", "matrix", "class", "probs")
type = c("vector", "prob", "class", "matrix")
The last of these is how it has remained in the 18 years since. (This fun little spelunk was made possible by METACRAN's read-only CRAN mirror!)
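If you want to check this from R rather than from the archive, the allowed values can be read off the method's own signature. A minimal sketch, assuming a current rpart: predict.rpart is reached with ::: so it works whether or not the method is exported, and the iris tree is only an illustrative fit.

```r
library(rpart)

# Read the default value of the `type` argument straight from the method's signature
allowed_types <- eval(formals(rpart:::predict.rpart)$type)
print(allowed_types)

# A small illustrative classification tree (iris stands in for real data)
fit <- rpart(Species ~ ., data = iris, method = "class")

# predict.rpart() validates `type` with match.arg(), so "raw" should be rejected
raw_attempt <- tryCatch(predict(fit, newdata = iris, type = "raw"),
                        error = function(e) conditionMessage(e))
print(raw_attempt)
```

The error message comes from match.arg() and lists the four accepted values.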
Grid <- expand.grid(cp=seq(0, 0.2, 0.001))
fitControl <- trainControl(method = "cv", number = 6)
tree3 <- train(factor(TERM_FLAG)~.,
data=tlFLAG.t,
method='rpart',
trControl = fitControl,
metric = "Accuracy",
tuneGrid = Grid,
na.action = na.omit,
parms=list(split='Gini'))
plot(tree3)
rpart.plot(tree3$finalModel, extra=4)
# Training set
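# type = "raw" returns class labels; confusionMatrix() needs a factor, not a data frame of probabilities
pred3 <- predict(tree3, type = "raw")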
confusionMatrix(pred3, factor(tlFLAG.t$TERM_FLAG))
# Validation set
pred.v3 <- predict(tree3, type = "raw", newdata = tlFLAG.v)
confusionMatrix(pred.v3, factor(tlFLAG.v$TERM_FLAG))
What library() statements go along with that code?
library(rpart.plot)
library(rpart)
library(caret)
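For anyone who wants to run the same workflow end to end, here is a sketch on iris standing in for tlFLAG.t and tlFLAG.v. The data, split, seed, and cp grid are my assumptions, not the original poster's. It shows the two predict.train() output shapes side by side.

```r
library(rpart)
library(caret)

set.seed(42)
# Hypothetical stand-ins for tlFLAG.t / tlFLAG.v: a 70/30 split of iris
in_train <- createDataPartition(iris$Species, p = 0.7, list = FALSE)
train_set <- iris[in_train, ]
valid_set <- iris[-in_train, ]

grid <- expand.grid(cp = seq(0, 0.2, 0.01))
fit_control <- trainControl(method = "cv", number = 6)

tree_fit <- train(Species ~ .,
                  data = train_set,
                  method = "rpart",
                  trControl = fit_control,
                  metric = "Accuracy",
                  tuneGrid = grid)

# type = "raw": a factor of predicted classes
pred_class <- predict(tree_fit, newdata = valid_set, type = "raw")
# type = "prob": a data frame with one probability column per class
pred_prob <- predict(tree_fit, newdata = valid_set, type = "prob")

head(pred_prob)
# Validation accuracy, computed by hand to keep the sketch short
mean(pred_class == valid_set$Species)
```

The factor from type = "raw" is what confusionMatrix(pred_class, valid_set$Species) expects; the probability data frame is not.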
It's going to be impossible for me to completely settle this without a reproducible example, but here's what I'm guessing is happening in that code:
tree3 <- train(factor(TERM_FLAG)~.,
data=tlFLAG.t,
method='rpart',
trControl = fitControl,
metric = "Accuracy",
tuneGrid = Grid,
na.action = na.omit,
parms=list(split='Gini'))
It looks like the tree3 object is being created by caret::train(), which returns a list with the class train.
As an S3 generic function, predict() uses the class of the object passed to it to determine which specific method should be called (you can use class() to see an object's class(es)). Therefore, when tree3 is later passed to the generic predict() function, the method that actually runs is predict.train(), not predict.rpart().
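To see the dispatch rule in isolation, here is a toy sketch. describe() and its methods are hypothetical, and fake_fit only mimics the shape of a caret result (a list of class "train" holding an rpart finalModel); nothing here calls caret or rpart.

```r
# Hypothetical generic, used only to illustrate S3 dispatch
describe <- function(x, ...) UseMethod("describe")
describe.default <- function(x, ...) "describe.default"
describe.train <- function(x, ...) "describe.train"
describe.rpart <- function(x, ...) "describe.rpart"

# Fake object shaped like a caret fit: class "train", with an "rpart" finalModel
fake_fit <- structure(
  list(finalModel = structure(list(), class = "rpart")),
  class = c("train", "train.formula")
)

describe(fake_fit)             # first class is "train", so describe.train() runs
describe(fake_fit$finalModel)  # class "rpart", so describe.rpart() runs
describe(42)                   # no numeric method, so it falls back to describe.default()
```

By the same rule, predict(tree3$finalModel, ...) would dispatch to predict.rpart(), since finalModel holds the underlying rpart object.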
If you consult the predict.train() documentation, you'll see a brief mention of what type = "raw" means:
type
either "raw" or "prob", for the number/class predictions or class probabilities, respectively. Class probabilities are not available for all classification models.
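Whether a particular model supports type = "prob" can be checked up front. A minimal sketch using caret's modelLookup(), assuming a caret version whose lookup table includes the forClass and probModel columns:

```r
library(caret)

# One row per tuning parameter; flags say whether the model supports
# regression, classification, and class probabilities
rpart_lookup <- modelLookup("rpart")
print(rpart_lookup)

# TRUE means predict(..., type = "prob") is available for this model
rpart_lookup$probModel
```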
There's more detail in the full caret documentation, but confusingly, I think there's a mistake in the docs! The relevant section says:
Also, there are very few standard syntaxes for model predictions in R. For example, to get class probabilities, many predict methods have an argument called type that is used to specify whether the classes or probabilities should be generated. Different packages use different values of type, such as "prob", "posterior", "response", "probability" or "raw". In other cases, completely different syntax is used.
For predict.train, the type options are standardized to be "class" and "prob" (the underlying code matches these to the appropriate choices for each model).
But I'm pretty certain that last line ought to read: the type options are standardized to be "raw" and "prob". The following line of code has been in predict.train() since at least 2008:
if(!(type %in% c("raw", "prob"))) stop("type must be either \"raw\" or \"prob\"")
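You can see that check fire by asking a train object for type = "class". A sketch on iris with no resampling, purely to get a train object quickly; the single cp value is arbitrary.

```r
library(caret)

# Fit one rpart model without resampling, just to obtain a `train` object
fit <- train(Species ~ ., data = iris, method = "rpart",
             trControl = trainControl(method = "none"),
             tuneGrid = data.frame(cp = 0.01))

# predict.train() accepts only "raw" or "prob", so "class" should hit the stop()
class_attempt <- tryCatch(predict(fit, newdata = iris, type = "class"),
                          error = function(e) conditionMessage(e))
print(class_attempt)
```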
What does type = "raw" mean in terms of rpart?
When you pass type = "raw" to caret's predict.train(), it uses its own logic to map that onto the underlying model in order to generate predictions as raw numbers or classes. To see exactly which prediction function is being called, use getModelInfo():
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
rpart_info <- getModelInfo("rpart")
# This is what gets called for `type = "raw"`
rpart_info$rpart$predict
#> function(modelFit, newdata, submodels = NULL) {
#> if(!is.data.frame(newdata)) newdata <- as.data.frame(newdata)
#>
#> pType <- if(modelFit$problemType == "Classification") "class" else "vector"
#> out <- predict(modelFit, newdata, type=pType)
#>
#> if(!is.null(submodels))
#> {
#> tmp <- vector(mode = "list", length = nrow(submodels) + 1)
#> tmp[[1]] <- out
#> for(j in seq(along = submodels$cp))
#> {
#> prunedFit <- rpart::prune.rpart(modelFit, cp = submodels$cp[j])
#> tmp[[j+1]] <- predict(prunedFit, newdata, type=pType)
#> }
#> out <- tmp
#> }
#> out
#> }
# This is what gets called for `type = "prob"`
rpart_info$rpart$prob
#> function(modelFit, newdata, submodels = NULL) {
#> if(!is.data.frame(newdata)) newdata <- as.data.frame(newdata)
#> out <- predict(modelFit, newdata, type = "prob")
#>
#> if(!is.null(submodels))
#> {
#> tmp <- vector(mode = "list", length = nrow(submodels) + 1)
#> tmp[[1]] <- out
#> for(j in seq(along = submodels$cp))
#> {
#> prunedFit <- rpart::prune.rpart(modelFit, cp = submodels$cp[j])
#> tmpProb <- predict(prunedFit, newdata, type = "prob")
#> tmp[[j+1]] <- as.data.frame(tmpProb[, modelFit$obsLevels, drop = FALSE])
#> }
#> out <- tmp
#> }
#> out
#> }
Created on 2018-11-20 by the reprex package (v0.2.1)
The important lines for the type = "raw" case:
pType <- if(modelFit$problemType == "Classification") "class" else "vector"
out <- predict(modelFit, newdata, type=pType)
And for the type = "prob" case:
out <- predict(modelFit, newdata, type = "prob")
So passing type = "raw" to predict.train() gives you either type = "class" (classification) or type = "vector" (regression) in the underlying call to predict.rpart(). Passing type = "prob" to predict.train() gives you type = "prob" in the underlying call to predict.rpart().
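Those mappings can be checked empirically by comparing predict.train() with a direct call to predict.rpart() on finalModel. A sketch, assuming all-numeric predictors (iris and mtcars) so the data frame can be passed straight to the underlying rpart object; with factor predictors, the formula interface of train() creates dummy columns and this direct comparison would need the same encoding.

```r
library(caret)

# Classification: one fit, no resampling, so finalModel is a plain rpart object
cls_fit <- train(Species ~ ., data = iris, method = "rpart",
                 trControl = trainControl(method = "none"),
                 tuneGrid = data.frame(cp = 0.01))

# caret's type = "raw" vs. rpart's type = "class"
raw_via_caret <- predict(cls_fit, newdata = iris, type = "raw")
class_via_rpart <- predict(cls_fit$finalModel, newdata = iris, type = "class")

# caret's type = "prob" vs. rpart's type = "prob"
prob_via_caret <- predict(cls_fit, newdata = iris, type = "prob")
prob_via_rpart <- predict(cls_fit$finalModel, newdata = iris, type = "prob")

# Regression: caret's type = "raw" should map onto rpart's type = "vector"
reg_fit <- train(mpg ~ ., data = mtcars, method = "rpart",
                 trControl = trainControl(method = "none"),
                 tuneGrid = data.frame(cp = 0.01))
raw_reg_via_caret <- predict(reg_fit, newdata = mtcars, type = "raw")
vector_via_rpart <- predict(reg_fit$finalModel, newdata = mtcars, type = "vector")

# Compare values only; names and row names may differ between the two paths
identical(as.character(raw_via_caret), as.character(class_via_rpart))
all.equal(unname(as.matrix(prob_via_caret)), unname(prob_via_rpart))
all.equal(unname(raw_reg_via_caret), unname(vector_via_rpart))
```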