Factor covariates with caret::extractPrediction()

I'm having trouble using caret::extractPrediction() with a GLM that has factor covariates (see the reprex below). It appears that caret does not use the model matrix constructed by glm(). I can build the model matrix myself so that I'm not passing factors into caret::train(), but this feels a bit hacky.
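A minimal sketch of the model-matrix workaround described above, assuming caret is installed. It expands the cyl factor into dummy columns with base R's model.matrix() and then trains through the x/y interface, so the final glm is fit on, and later predicts from, the same numeric columns. The names x and train_glm_mm are illustrative only.

```r
library(caret)

# Same data as the reprex: cyl converted to a factor
mtcars2 <- mtcars
mtcars2$cyl <- as.factor(mtcars2$cyl)

# Expand factors into dummy columns (cyl6, cyl8) and drop the intercept
# column, since caret's glm method fits a model with its own intercept
x <- as.data.frame(model.matrix(mpg ~ ., data = mtcars2)[, -1])

set.seed(1)
train_glm_mm <- caret::train(
    x = x
  , y = mtcars2$mpg
  , method = 'glm'
  , trControl = trainControl(method = 'cv', number = 5)
)

# The stored training data already contain cyl6 and cyl8,
# so extractPrediction() can find the columns the model needs
preds <- extractPrediction(list(train_glm_mm))
head(preds)
```

The trade-off is that any new data must go through the same model.matrix() call before it is passed to predict().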

This question also appears on Stack Overflow here: https://stackoverflow.com/questions/29490751/does-extractprediction-support-factors, but it has no answers.

Reprex:

library(tidyverse)
#> ── Attaching packages ───────────────── tidyverse 1.2.1 ──
#> ✔ ggplot2 3.0.0     ✔ purrr   0.2.5
#> ✔ tibble  1.4.2     ✔ dplyr   0.7.6
#> ✔ tidyr   0.8.1     ✔ stringr 1.3.1
#> ✔ readr   1.1.1     ✔ forcats 0.3.0
#> ── Conflicts ──────────────────── tidyverse_conflicts() ──
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag()    masks stats::lag()
library(caret)
#> Loading required package: lattice
#> 
#> Attaching package: 'caret'
#> The following object is masked from 'package:purrr':
#> 
#>     lift

data(mtcars)
mtcars2 <- mtcars %>% 
  mutate(cyl = as.factor(cyl))

train_glm <- caret::train(
    mpg ~ .
  , data = mtcars2
  , method = 'glm'
  , trControl = trainControl(
    method = 'cv'
    , number = 5
  )
)

extractPrediction(
  list(train_glm)
)
#> Error in eval(predvars, data, env): object 'cyl6' not found

The answer is that extractPrediction() does not work well with objects of class train.formula. I spent some time reading the caret source, and predict.train() and extractPrediction() treat the data used for prediction slightly differently: for train.formula objects, predict.train() calls model.frame() to adjust the data before predicting, while extractPrediction() does not. As a result, the data handed to the final glm still contain the cyl factor rather than the dummy column cyl6 that the model was fit on, which is where the error comes from. It may be possible to change this behavior. If you've read this far, meet me over on the caret issues list.
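Until that changes, one possible workaround for a train.formula object is to assemble the columns that extractPrediction() reports (obs, pred, model, dataType) yourself by calling predict(), which does apply the stored terms. This is a sketch under that assumption, not caret's own implementation; manual_preds is a hypothetical name.

```r
library(caret)

mtcars2 <- mtcars
mtcars2$cyl <- as.factor(mtcars2$cyl)

set.seed(1)
train_glm_formula <- caret::train(
    mpg ~ .
  , data = mtcars2
  , method = 'glm'
  , trControl = trainControl(method = 'cv', number = 5)
)

# For train.formula objects, predict.train() rebuilds the design matrix
# from the stored terms, so cyl is expanded before the glm predicts
manual_preds <- data.frame(
    obs      = mtcars2$mpg
  , pred     = unname(predict(train_glm_formula, newdata = mtcars2))
  , model    = train_glm_formula$method
  , dataType = "Training"
  , stringsAsFactors = FALSE
)

head(manual_preds)
```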

Answer reprex:

suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(caret))

data(mtcars)
mtcars2 <- mtcars %>% 
  mutate(cyl = as.factor(cyl))

train_glm_formula <- caret::train(
    mpg ~ .
  , data = mtcars2
  , method = 'glm'
  , trControl = trainControl(
    method = 'cv'
    , number = 5
    , returnData = TRUE
  )
)

# This works
predict(train_glm_formula) %>% 
  head()
#>        1        2        3        4        5        6 
#> 21.80150 21.19320 26.27339 19.64456 17.70480 18.97976

# This doesn't
extractPrediction(
  list(train_glm_formula)
) %>% 
  head()
#> Error in eval(predvars, data, env): object 'cyl6' not found

train_glm_default <- caret::train(
    x = mtcars2 %>% select(-mpg)
  , y = mtcars2$mpg
  , method = 'glm'
  , trControl = trainControl(
    method = 'cv'
    , number = 5
    , returnData = TRUE
  )
)

# This works
predict(train_glm_default) %>% 
  head()
#>        1        2        3        4        5        6 
#> 21.80150 21.19320 26.27339 19.64456 17.70480 18.97976

# And so does this
extractPrediction(
  list(train_glm_default)
) %>% 
  head()
#>    obs     pred model dataType  object
#> 1 21.0 21.80150   glm Training Object1
#> 2 21.0 21.19320   glm Training Object1
#> 3 22.8 26.27339   glm Training Object1
#> 4 21.4 19.64456   glm Training Object1
#> 5 18.7 17.70480   glm Training Object1
#> 6 18.1 18.97976   glm Training Object1

If you’ve posted this as an issue, can you please include the link so others might follow the breadcrumbs?

I haven't posted the issue yet. I have a workaround in place, so I need to wrap something else up before I can post on GitHub.
