I'm having trouble using caret::extractPrediction() for a GLM with factor covariates (see the reprex below). It appears that caret does not use the model matrix that glm() constructs. I can build the model matrix myself so that I'm not passing factors into caret::train(), but this feels a bit hacky.
The answer is that extractPrediction() does not play nice with an object of class train.formula. I spent a bit of time with the caret source, and predict.train() and extractPrediction() differ slightly in how they treat the data used for prediction. For objects of class train.formula, predict.train() calls model.frame() to adjust the data; extractPrediction() doesn't, so the factor cyl apparently never gets expanded into the dummy columns (such as cyl6) that the fitted model expects, which is where the error in the reprex comes from. It may be possible to change this behavior. Meet me over on the caret issues list if you've read this far.
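Until that changes, one way around it is to lean on predict(), which does handle the formula interface, and assemble the observed/predicted table by hand. The following is only a sketch: the column names mirror what extractPrediction() printed for the default-interface model, the names fit and preds are made up for illustration, and the "Object1" label is hard-coded as an assumption rather than taken from caret.

```r
library(caret)

# Same setup as the reprex, in base R: cyl becomes a factor
mtcars2 <- mtcars
mtcars2$cyl <- as.factor(mtcars2$cyl)

set.seed(1)
fit <- train(
  mpg ~ .
  , data = mtcars2
  , method = 'glm'
  , trControl = trainControl(method = 'cv', number = 5)
)

# predict() goes through predict.train(), which adjusts the data
# for train.formula objects, so the factor column is not a problem
preds <- data.frame(
  obs = mtcars2$mpg
  , pred = unname(predict(fit, newdata = mtcars2))
  , model = fit$method
  , dataType = 'Training'
  , object = 'Object1'  # hard-coded label (assumption)
)

head(preds)
```

This only covers the training-set rows; held-out data would need its own predict() call with the appropriate newdata.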
Answer reprex:
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(caret))

data(mtcars)

mtcars2 <- mtcars %>%
  mutate(cyl = as.factor(cyl))

train_glm_formula <- caret::train(
  mpg ~ .
  , data = mtcars2
  , method = 'glm'
  , trControl = trainControl(
    method = 'cv'
    , number = 5
    , returnData = TRUE
  )
)

# This works
predict(train_glm_formula) %>%
  head()
#>        1        2        3        4        5        6
#> 21.80150 21.19320 26.27339 19.64456 17.70480 18.97976

# This doesn't
extractPrediction(
  list(train_glm_formula)
) %>%
  head()
#> Error in eval(predvars, data, env): object 'cyl6' not found

train_glm_default <- caret::train(
  x = mtcars2 %>% select(-mpg)
  , y = mtcars2$mpg
  , method = 'glm'
  , trControl = trainControl(
    method = 'cv'
    , number = 5
    , returnData = TRUE
  )
)

# This works
predict(train_glm_default) %>%
  head()
#>        1        2        3        4        5        6
#> 21.80150 21.19320 26.27339 19.64456 17.70480 18.97976

# And so does this
extractPrediction(
  list(train_glm_default)
) %>%
  head()
#>    obs     pred model dataType  object
#> 1 21.0 21.80150   glm Training Object1
#> 2 21.0 21.19320   glm Training Object1
#> 3 22.8 26.27339   glm Training Object1
#> 4 21.4 19.64456   glm Training Object1
#> 5 18.7 17.70480   glm Training Object1
#> 6 18.1 18.97976   glm Training Object1
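
For completeness, the model-matrix workaround mentioned in the question could look roughly like the sketch below. It expands the factor with model.matrix() up front and uses the x/y interface, so no factor ever reaches train(). I'm assuming extractPrediction() treats the resulting numeric matrix the same way it treats the data frame in the default-interface example above; the names x_mm, fit_mm and preds_mm are just illustrative.

```r
library(caret)

mtcars2 <- mtcars
mtcars2$cyl <- as.factor(mtcars2$cyl)

# Expand the factor into dummy columns (cyl6, cyl8) ourselves and
# drop the intercept column, since glm adds its own
x_mm <- model.matrix(mpg ~ ., data = mtcars2)[, -1]

set.seed(1)
fit_mm <- train(
  x = x_mm
  , y = mtcars2$mpg
  , method = 'glm'
  , trControl = trainControl(
    method = 'cv'
    , number = 5
    , returnData = TRUE
  )
)

# The training data is already numeric, so there is nothing
# left for extractPrediction() to expand
preds_mm <- extractPrediction(list(fit_mm))
head(preds_mm)
```

The cost is the one the question already points out: the dummy coding lives outside the model object, so any new data has to be run through the same model.matrix() step before prediction.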