VarImp For Rpart when using

Hi,

I am currently looking at CART trees in relation to variable importance.
The documentation for caret describes a function called varImp.
How it calculates variable importance differs depending on the model.
For rpart it also says: "This method does not currently provide class-specific measures of importance when the response is a factor."

When I create an rpart model as below, I am able to use varImp. Can anyone tell me how this is calculated? Is it based on the drop in the Gini index when the variable is permuted or dropped?

Thanks for your time

library(rpart)
library(caret)
#> Warning: package 'caret' was built under R version 3.5.1
#> Loading required package: lattice
#> Loading required package: ggplot2
#> Warning: package 'ggplot2' was built under R version 3.5.1
library(tidyverse)
#> Warning: package 'tidyverse' was built under R version 3.5.1
#> Warning: package 'dplyr' was built under R version 3.5.1
# Get the Data
data(GermanCredit)

rf_mod <- rpart(Class~.,data =  GermanCredit)

caret::varImp(rf_mod) %>% 
  rownames_to_column() %>% 
  arrange(desc(Overall)) %>% 
  slice(1:10)
#>                             rowname  Overall
#> 1                            Amount 57.82419
#> 2                          Duration 47.32593
#> 3        CheckingAccountStatus.none 43.66521
#> 4        CheckingAccountStatus.lt.0 37.87057
#> 5            CreditHistory.Critical 24.11095
#> 6                    Purpose.NewCar 20.31030
#> 7                   Purpose.UsedCar 18.56253
#> 8      CheckingAccountStatus.gt.200 17.39552
#> 9  OtherDebtorsGuarantors.Guarantor 11.40171
#> 10                 Property.Unknown 11.25420

Created on 2018-11-21 by the reprex package (v0.2.1)

You can view the actual code using

caret:::getModelInfo("rpart", FALSE)[[1]]$varImp

?varImp has

Recursive Partitioning: The reduction in the loss function (e.g. mean squared error) attributed to each variable at each split is tabulated and the sum is returned. Also, since there may be candidate variables that are important but are not used in a split, the top competing variables are also tabulated at each split. This can be turned off using the maxcompete argument in rpart.control. This method does not currently provide class-specific measures of importance when the response is a factor.

Classic CART trees do not use the random forest permutation methods for measuring importance (since there are no out-of-bag samples).

Here's a simple example:

library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
library(rpart)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

ctrl <- rpart.control(cp = .1, maxcompete = 0)
two_split <- rpart(Species ~ ., data = iris, control = ctrl)
two_split
#> n= 150 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
#>   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
#>     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
#>     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *

# Works off of the `improve` column here: 
two_split$splits
#>              count ncat    improve index       adj
#> Petal.Length   150   -1 50.0000000  2.45 0.0000000
#> Petal.Width      0   -1  1.0000000  0.80 1.0000000
#> Sepal.Length     0   -1  0.9200000  5.45 0.7600000
#> Sepal.Width      0    1  0.8333333  3.35 0.5000000
#> Petal.Width    100   -1 38.9694042  1.75 0.0000000
#> Petal.Length     0   -1  0.9100000  4.75 0.8043478
#> Sepal.Length     0   -1  0.7300000  6.15 0.4130435
#> Sepal.Width      0   -1  0.6700000  2.95 0.2826087

# the actual splits:
two_split$splits %>%
  as.data.frame() %>% 
  filter(count > 0)
#>   count ncat improve index adj
#> 1   150   -1 50.0000  2.45   0
#> 2   100   -1 38.9694  1.75   0

varImp(two_split, surrogates = FALSE, competes = FALSE)
#>              Overall
#> Petal.Length 50.0000
#> Petal.Width  38.9694
#> Sepal.Length  0.0000
#> Sepal.Width   0.0000

Created on 2018-11-21 by the reprex package (v0.2.1)
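To make the tabulation concrete, here is a minimal sketch that rebuilds the same numbers by hand from the fitted object, using only rpart. It is not caret's actual code: the helper primary_split_rows() and the name manual_imp are made up for illustration, and the row lookup assumes the usual layout in which each internal node of fit$frame owns one primary row in fit$splits followed by its competitor and surrogate rows.

```r
library(rpart)

# Same two-split tree as above (Gini splitting, no competing splits kept)
ctrl <- rpart.control(cp = 0.1, maxcompete = 0)
fit <- rpart(Species ~ ., data = iris, control = ctrl)

# Hypothetical helper: find the row of fit$splits holding the primary split
# of each internal node. Each internal node contributes 1 primary row plus
# ncompete competitor rows and nsurrogate surrogate rows; leaves contribute 0.
primary_split_rows <- function(fit) {
  ff <- fit$frame
  is_split <- ff$var != "<leaf>"
  first_row <- cumsum(c(1, ff$ncompete + ff$nsurrogate + is_split))
  first_row[which(is_split)]
}

rows <- primary_split_rows(fit)
vars <- rownames(fit$splits)[rows]
improve <- unname(fit$splits[rows, "improve"])

# Sum the improvement per variable; variables never used in a primary
# split simply do not appear (varImp reports them as 0 instead)
manual_imp <- tapply(improve, vars, sum)
manual_imp
```

If a variable were used in more than one split, its improve values would be added together, which is the "sum is returned" part of the ?varImp description.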

Hi @Max

Thank you for the quick reply. If I understand this correctly, varImp uses the same variables that the tree picked for its splits, i.e. the ones that give the most gain? Is this the improve metric from your two_split$splits call?

What is the improve metric for classification? I had a look at the code but, to be honest, it's a bit over my head. From stepping through it, is the loss function the same one used when I build the tree, for example information loss?

Thank you again for your time

Yes and yes. Note that, for regression trees, the loss is squared error (e.g. the mean squared error mentioned in ?varImp) rather than Gini or information.

Yes, the loss function is the one used to build the tree. For example, with information splitting:

library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
library(rpart)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

ctrl <- rpart.control(cp = .1, maxcompete = 0)
args <- list(split = "information")
two_split <- rpart(Species ~ ., data = iris, control = ctrl, parms = args)
two_split
#> n= 150 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 150 100 setosa (0.33333333 0.33333333 0.33333333)  
#>   2) Petal.Length< 2.45 50   0 setosa (1.00000000 0.00000000 0.00000000) *
#>   3) Petal.Length>=2.45 100  50 versicolor (0.00000000 0.50000000 0.50000000)  
#>     6) Petal.Width< 1.75 54   5 versicolor (0.00000000 0.90740741 0.09259259) *
#>     7) Petal.Width>=1.75 46   1 virginica (0.00000000 0.02173913 0.97826087) *

# Works off of the `improve` column here: 
two_split$splits
#>              count ncat    improve index       adj
#> Petal.Length   150   -1 95.4771252  2.45 0.0000000
#> Petal.Width      0   -1  1.0000000  0.80 1.0000000
#> Sepal.Length     0   -1  0.9200000  5.45 0.7600000
#> Sepal.Width      0    1  0.8333333  3.35 0.5000000
#> Petal.Width    100   -1 47.8382715  1.75 0.0000000
#> Petal.Length     0   -1  0.9100000  4.75 0.8043478
#> Sepal.Length     0   -1  0.7300000  6.15 0.4130435
#> Sepal.Width      0   -1  0.6700000  2.95 0.2826087

# the actual splits:
two_split$splits %>%
  as.data.frame() %>% 
  filter(count > 0)
#>   count ncat  improve index adj
#> 1   150   -1 95.47713  2.45   0
#> 2   100   -1 47.83827  1.75   0

varImp(two_split, surrogates = FALSE, competes = FALSE)
#>               Overall
#> Petal.Length 95.47713
#> Petal.Width  47.83827
#> Sepal.Length  0.00000
#> Sepal.Width   0.00000

Created on 2018-11-21 by the reprex package (v0.2.1)
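As a rough check on what improve means for classification, the values printed above can be reproduced in base R as the n-weighted decrease in node impurity. This is a sketch under assumptions, not rpart's internal code: the function names impurity() and split_improve() are invented here, the "information" impurity is taken to be entropy with natural logs, and the agreement is only expected with the default priors and no loss matrix.

```r
# Impurity of a set of class labels; "information" is entropy in natural logs
impurity <- function(y, type = c("gini", "information")) {
  type <- match.arg(type)
  p <- as.vector(table(y)) / length(y)
  p <- p[p > 0]                       # drop empty classes to avoid 0 * log(0)
  if (type == "gini") 1 - sum(p^2) else -sum(p * log(p))
}

# n-weighted impurity of the parent minus that of the two children
split_improve <- function(y, left, type = "gini") {
  n_weighted <- function(v) length(v) * impurity(v, type)
  n_weighted(y) - n_weighted(y[left]) - n_weighted(y[!left])
}

# The two splits found by the trees above
y <- iris$Species
root_left <- iris$Petal.Length < 2.45            # node 2 vs node 3
in_node3 <- !root_left
y3 <- y[in_node3]
node3_left <- iris$Petal.Width[in_node3] < 1.75  # node 6 vs node 7

gini_improve <- c(split_improve(y, root_left, "gini"),
                  split_improve(y3, node3_left, "gini"))
info_improve <- c(split_improve(y, root_left, "information"),
                  split_improve(y3, node3_left, "information"))

round(gini_improve, 4)
round(info_improve, 4)
```

For the root split with Gini this is 150 * 2/3 - 50 * 0 - 100 * 0.5 = 50, which matches the improve value for Petal.Length in the first example.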


Brilliant, thanks very much!
