Visualizing median, min and max in decision trees

Hi everyone! I'm pretty new to R and am trying to answer a specific question about plotting decision trees.

When I visualize my decision tree with rpart.plot, I can only see the average value in each node. Can I also show the median, max and min values? I wasn't able to find anything in the rpart documentation; maybe someone has faced the same issue?

I think these can be calculated fairly easily for the leaf nodes. The object returned by rpart() has a where element that identifies the leaf each observation was assigned to. You can use this to group the data and calculate any desired statistics. For intermediate nodes, I think you could use the information in the splits element of the rpart object to reconstruct each node. That may be a manual process, though I am not sure. Can you provide some data and an example of how you are using rpart()?


Hey! Thanks for the idea. I found that where element too, but didn't manage to figure out how to use it.

Here is a data sample (the values are made up):
Band MedianRate Value
6 1.03 1000
6 1.11 1200
5 0.87 800

fit <- rpart(Value~Band+MedianRate, data=data_cleansed, method="anova")

I even managed to find out how to add extra text to the node labels using

node.fun1 <- function(x, labs, digits, varlen) {
  paste(labs, "\ndev", x$frame$dev)
}
prp(fit, extra = 6, node.fun = node.fun1)

but this seems to involve only the fit$frame columns, which isn't enough in my case. Ideally, I would like to see in each node:

  • average for node group
  • median for node group
  • min Value for node group
  • max Value for node group

Would that be enough?
Thanks in advance.

Below are some calculations on a data set I invented. The mean, median, etc. of the leaf nodes are computed first; it is easy to match LeafStats$LeafNum to the rpart.plot by comparing the Mean values.

I started the calculations for the other nodes, and as you can see they are much uglier. The index column of fit$splits gives the exact value used to split each node; for example, the top node is split at MedianRate < 1.7635074. Competing splits that were not used, as well as surrogate splits (the rows with nonzero adj), are also listed in Splits, so you have to be careful. The fit$frame element gives the node numbers and the mean y value of each node.

My function GetStats() uses the NodeNum simply as a label in case you wanted to join the individual node data frames together.

If the tree has many nodes, these manual calculations would be very tedious. I do not see at the moment how to automate the calculations for the non-leaf nodes.
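Picking the primary split of each node out of fit$splits can at least be automated. This is a minimal sketch, assuming (as the rows of the Splits printout below suggest) that fit$splits stores, for each internal node in fit$frame row order, the primary split followed by its ncompete competing splits and nsurrogate surrogate splits, with no rows for leaves. The names rows_per_node, first_row and PrimarySplits are hypothetical.

```r
library(rpart)

# Invented data, as in this post
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)
fit <- rpart(Value ~ Band + MedianRate, data = DF, method = "anova")

ff <- fit$frame
is_split <- ff$var != "<leaf>"

# Assumption: each internal node owns 1 primary + ncompete competing +
# nsurrogate surrogate rows in fit$splits, in fit$frame row order; leaves own none
rows_per_node <- ifelse(is_split, 1 + ff$ncompete + ff$nsurrogate, 0)
# Position of each node's first (primary) row in fit$splits
first_row <- cumsum(c(1, rows_per_node))[seq_len(nrow(ff))]
primary <- first_row[is_split]

PrimarySplits <- data.frame(
  Node  = as.integer(rownames(ff)[is_split]),  # rpart node number
  Var   = rownames(fit$splits)[primary],        # variable of the primary split
  ncat  = fit$splits[primary, "ncat"],          # -1 / +1 for continuous variables
  index = fit$splits[primary, "index"]          # split point
)
PrimarySplits
```

The direction of each split still has to be read from ncat: in the Splits printout, ncat = -1 at the top node corresponds to MedianRate < index going to node 2.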

#Invent data
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)

library(rpart)
library(rpart.plot)
library(dplyr)

fit <- rpart(Value~Band+MedianRate, data=DF, method="anova")
rpart.plot(fit)


#Calculate statistics of the leaves
LeafStats <- DF %>% mutate(LeafNum = fit$where) %>% 
  group_by(LeafNum) %>% 
  summarise(Mean = mean(Value), Median = median(Value), 
            Min = min(Value), Max = max(Value), N = n())
LeafStats #See the print out of Frame below to understand LeafNum
#> # A tibble: 5 x 6
#>   LeafNum  Mean Median   Min   Max     N
#>     <int> <dbl>  <dbl> <dbl> <dbl> <int>
#> 1       3  879.   869.  810.  944.    15
#> 2       5  909.   920.  837.  999.    13
#> 3       6  942.   930.  894. 1007.    12
#> 4       8  916.   931.  863.  955.     9
#> 5       9  946.   929.  882. 1044.    11

###Node Calculations
Splits <- fit$splits
Splits
#>            count ncat      improve     index       adj
#> MedianRate    60   -1 0.0578866182 1.7635074 0.0000000
#> Band          60   -1 0.0012291083 5.5000000 0.0000000
#> MedianRate    40    1 0.2155986209 1.1872099 0.0000000
#> Band          40    1 0.0002392061 5.5000000 0.0000000
#> Band          25   -1 0.1266509674 5.5000000 0.0000000
#> MedianRate    25   -1 0.1038963104 0.7519308 0.0000000
#> MedianRate     0   -1 0.7200000000 0.9077468 0.4166667
#> MedianRate    20    1 0.1026266061 2.0550743 0.0000000
#> Band          20   -1 0.0111632486 5.5000000 0.0000000
#> Band           0    1 0.6500000000 6.5000000 0.2222222
Frame <- fit$frame
Frame
#>           var  n wt        dev     yval complexity ncompete nsurrogate
#> 1  MedianRate 60 60 142627.676 916.1906 0.09688132        1          0
#> 2  MedianRate 40 40  89887.778 907.8959 0.09688132        1          0
#> 4      <leaf> 15 15  18247.778 879.4796 0.01000000        0          0
#> 5        Band 25 25  52260.318 924.9457 0.04640628        1          1
#> 10     <leaf> 13 13  30606.923 909.3128 0.01000000        0          0
#> 11     <leaf> 12 12  15034.576 941.8814 0.01000000        0          0
#> 3  MedianRate 20 20  44483.665 932.7800 0.03200787        1          1
#> 6      <leaf>  9  9   8638.785 916.0772 0.01000000        0          0
#> 7      <leaf> 11 11  31279.672 946.4460 0.01000000        0          0

GetStats <- function(x, NodeNum) {
  summarise(x, Mean = mean(Value), Median = median(Value), 
            Min = min(Value), Max = max(Value), 
            N = n(), Perc = n()/nrow(DF)) %>% 
    mutate(Node = NodeNum)
}

Node2 <- DF %>% filter(MedianRate < Splits[1, "index"])  
Node2Stats <- GetStats(Node2, NodeNum = 2)
Node2Stats
#>       Mean   Median      Min      Max  N      Perc Node
#> 1 907.8959 912.0706 809.7521 1007.152 40 0.6666667    2

Node3 <- DF %>% filter(MedianRate >= Splits[1, "index"])  
Node3Stats <- GetStats(Node3, NodeNum = 3)
Node3Stats
#>     Mean   Median      Min      Max  N      Perc Node
#> 1 932.78 930.1466 862.8363 1044.358 20 0.3333333    3

Node4 <- Node2 %>% filter(MedianRate >= Splits[3, "index"])  
Node4Stats <- GetStats(Node4, NodeNum = 4)
Node4Stats
#>       Mean   Median      Min      Max  N Perc Node
#> 1 879.4796 869.3987 809.7521 944.0554 15 0.25    4

Node5 <- Node2 %>% filter(MedianRate < Splits[3, "index"])  
Node5Stats <- GetStats(Node5, NodeNum = 5)
Node5Stats
#>       Mean   Median      Min      Max  N      Perc Node
#> 1 924.9457 923.7865 837.3183 1007.152 25 0.4166667    5

Created on 2019-10-11 by the reprex package (v0.3.0.9000)
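One way to automate the non-leaf calculations, sketched under an assumption about rpart's node numbering: the children of node k are nodes 2k and 2k + 1 (consistent with the node numbers 1, 2, 4, 5, 10, 11, 3, 6, 7 in Frame above). Each observation's leaf then determines every node on its path back to the root, so grouping by node gives the statistics of all nodes at once. ancestors(), Paths and NodeStats are hypothetical names.

```r
library(rpart)
library(dplyr)

# Invented data, as above
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)
fit <- rpart(Value ~ Band + MedianRate, data = DF, method = "anova")

# fit$where is a row position in fit$frame; the row names are the node numbers
leaf_node <- as.integer(rownames(fit$frame))[fit$where]

# Assumption: the parent of node k is k %/% 2, so repeated halving
# walks from a leaf up to the root (node 1)
ancestors <- function(k) {
  out <- k
  while (k > 1) {
    k <- k %/% 2
    out <- c(out, k)
  }
  out
}

# One row per (observation, node on the path from its leaf to the root)
paths <- lapply(leaf_node, ancestors)
Paths <- data.frame(Node = unlist(paths),
                    Value = rep(DF$Value, lengths(paths)))

NodeStats <- Paths %>%
  group_by(Node) %>%
  summarise(Mean = mean(Value), Median = median(Value),
            Min = min(Value), Max = max(Value), N = n())
NodeStats
```

If the numbering assumption holds, the leaf rows should reproduce LeafStats and the rows for nodes 2 to 5 should match Node2Stats to Node5Stats, without any filtering by split values.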


I just realized that I was confused about the value of fit$where. It is the row number in fit$frame where the leaf appears; the node number is the row name in fit$frame. For example, in LeafStats, LeafNum = 3 has a mean value of 879. In Frame, the third row has that mean y value, but its node number is 4. Super confusing! Similarly, the eighth row of Frame holds the data for node number 6.
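A small sketch of that translation, relying only on what is described above: fit$where holds row positions in fit$frame, and the row names of fit$frame are the node numbers. RowInFrame and NodeNum are hypothetical names.

```r
library(rpart)

# Invented data, as above
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)
fit <- rpart(Value ~ Band + MedianRate, data = DF, method = "anova")

# Row position of each observation's leaf in fit$frame
RowInFrame <- fit$where
# Translate row positions into rpart node numbers via the row names
NodeNum <- as.integer(rownames(fit$frame))[RowInFrame]
# Cross-tabulate: each row position should map to exactly one node number
table(RowInFrame, NodeNum)
```

With the data above, row 3 should pair with node 4 and row 8 with node 6, as in the Frame printout.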

This is great and super cool! Thanks a lot for the solution. Not only did it teach me how to get data out of trees, but also how to use the pipe operator %>%!

For now I'm fine with the first part of the solution, having the Min, Max and Median values only for the leaves. The final piece I'm struggling with is putting these values into the plot itself. I might crack this later today (prp with node.fun should work), but if you happen to see an easy solution, I'd be grateful!

I would use the prune function to shrink the tree and make some of the non-leaf nodes into leaves. Here is an example.

#Invent data
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)

library(rpart)
library(rpart.plot)
library(dplyr)

fit <- rpart(Value~Band+MedianRate, data=DF, method="anova")
rpart.plot(fit)



fit$frame
#>           var  n wt        dev     yval complexity ncompete nsurrogate
#> 1  MedianRate 60 60 142627.676 916.1906 0.09688132        1          0
#> 2  MedianRate 40 40  89887.778 907.8959 0.09688132        1          0
#> 4      <leaf> 15 15  18247.778 879.4796 0.01000000        0          0
#> 5        Band 25 25  52260.318 924.9457 0.04640628        1          1
#> 10     <leaf> 13 13  30606.923 909.3128 0.01000000        0          0
#> 11     <leaf> 12 12  15034.576 941.8814 0.01000000        0          0
#> 3  MedianRate 20 20  44483.665 932.7800 0.03200787        1          1
#> 6      <leaf>  9  9   8638.785 916.0772 0.01000000        0          0
#> 7      <leaf> 11 11  31279.672 946.4460 0.01000000        0          0

##############################
#The non-leaf nodes with the smallest complexities are: 
#  Node Num 5 (yval = 924.9, complexity = 0.046)
#  Node Num 3 (yval = 932.8, complexity = 0.032)

#Let's prune with cp set to 0.05 and make those two nodes leaves
fit_05 <- prune(fit, 0.05)
rpart.plot(fit_05)

fit_05$frame
#>          var  n wt       dev     yval complexity ncompete nsurrogate
#> 1 MedianRate 60 60 142627.68 916.1906 0.09688132        1          0
#> 2 MedianRate 40 40  89887.78 907.8959 0.09688132        1          0
#> 4     <leaf> 15 15  18247.78 879.4796 0.01000000        0          0
#> 5     <leaf> 25 25  52260.32 924.9457 0.04640628        0          0
#> 3     <leaf> 20 20  44483.66 932.7800 0.03200787        0          0

#Calculate statistics of these leaves
LeafStats_05 <- DF %>% mutate(LeafNum = fit_05$where) %>% 
  group_by(LeafNum) %>% 
  summarise(Mean = mean(Value), Median = median(Value), 
            Min = min(Value), Max = max(Value), N = n())
LeafStats_05
#> # A tibble: 3 x 6
#>   LeafNum  Mean Median   Min   Max     N
#>     <int> <dbl>  <dbl> <dbl> <dbl> <int>
#> 1       3  879.   869.  810.  944.    15
#> 2       4  925.   924.  837. 1007.    25
#> 3       5  933.   930.  863. 1044.    20

Created on 2019-10-14 by the reprex package (v0.3.0.9000)

Thanks for the response! Now I see how to prune the lowest level, but I still can't figure out how to add the Mean, Max, Min and Median values to the graph.

Yes, sorry, I completely missed what you were trying to do next.
The prp function does have a parameter called node.fun that sets the text of the node labels.

That's right! prp with node.fun should solve this, and I've been struggling since this morning to get it done.
What I tried:

  1. Created the LeafStats tibble as you kindly suggested
  2. Merged it with the fit$frame table (left_join), simply adding the columns I needed
  3. Applied a node.fun function as described in the documentation
  4. ...failed, as somewhere along the way some nodes in my rpart object turned into NANA

Do you see any simple way to overcome this?

I will take a look at this about 9 or 10 hours from now.

Here is a start. For some reason I do not understand, I had to replace the dev column in fit$frame with my label text; if I appended a column to fit$frame, I got an error when plotting. I have not worked on making the prp plot prettier; the rpart.plot help is a good place to start for that.

#Invent data
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)

library(rpart)
library(rpart.plot)
library(dplyr)

fit <- rpart(Value~Band+MedianRate, data=DF, method="anova")
rpart.plot(fit)


#Calculate statistics of the leaves
LeafStats <- DF %>% mutate(LeafNum = fit$where) %>% 
  group_by(LeafNum) %>% 
  summarise(Mean = mean(Value), Median = median(Value), 
            Min = min(Value), Max = max(Value), N = n())
LeafStats
#> # A tibble: 5 x 6
#>   LeafNum  Mean Median   Min   Max     N
#>     <int> <dbl>  <dbl> <dbl> <dbl> <int>
#> 1       3  879.   869.  810.  944.    15
#> 2       5  909.   920.  837.  999.    13
#> 3       6  942.   930.  894. 1007.    12
#> 4       8  916.   931.  863.  955.     9
#> 5       9  946.   929.  882. 1044.    11
LeafStats2 <- LeafStats %>% 
  mutate(Mean = round(Mean), Median = round(Median), Max = round(Max), Min = round(Min),
         Lab = paste(Mean, Median, Min, Max)) %>%
  select(LeafNum, Lab)
  
FRAME <- fit$frame
FRAME$RowNum <- 1:nrow(FRAME)
FRAME <- left_join(FRAME, LeafStats2, by = c("RowNum" = "LeafNum"))
FRAME$Lab <-  ifelse(is.na(FRAME$Lab), round(FRAME$yval), FRAME$Lab)
FRAME <- select(FRAME, -RowNum)
FRAME
#>          var  n wt        dev     yval complexity ncompete nsurrogate
#> 1 MedianRate 60 60 142627.676 916.1906 0.09688132        1          0
#> 2 MedianRate 40 40  89887.778 907.8959 0.09688132        1          0
#> 3     <leaf> 15 15  18247.778 879.4796 0.01000000        0          0
#> 4       Band 25 25  52260.318 924.9457 0.04640628        1          1
#> 5     <leaf> 13 13  30606.923 909.3128 0.01000000        0          0
#> 6     <leaf> 12 12  15034.576 941.8814 0.01000000        0          0
#> 7 MedianRate 20 20  44483.665 932.7800 0.03200787        1          1
#> 8     <leaf>  9  9   8638.785 916.0772 0.01000000        0          0
#> 9     <leaf> 11 11  31279.672 946.4460 0.01000000        0          0
#>                Lab
#> 1              916
#> 2              908
#> 3  879 869 810 944
#> 4              925
#> 5  909 920 837 999
#> 6 942 930 894 1007
#> 7              933
#> 8  916 931 863 955
#> 9 946 929 882 1044
fit$frame$dev <- FRAME$Lab
node.fun1 <- function(x, labs, digits, varlen) {
  x$frame$dev
}
prp(fit, node.fun = node.fun1)

Created on 2019-10-14 by the reprex package (v0.2.1)

Here is a somewhat improved version of my last code.

#Invent data
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)

library(rpart)
library(rpart.plot)
library(dplyr)


fit <- rpart(Value~Band+MedianRate, data=DF, method="anova")
rpart.plot(fit)


#Calculate statistics of the leaves
LeafStats <- DF %>% mutate(LeafNum = fit$where) %>% 
  group_by(LeafNum) %>% 
  summarise(Mean = mean(Value), Median = median(Value), 
            Min = min(Value), Max = max(Value), N = n())
LeafStats
#> # A tibble: 5 x 6
#>   LeafNum  Mean Median   Min   Max     N
#>     <int> <dbl>  <dbl> <dbl> <dbl> <int>
#> 1       3  879.   869.  810.  944.    15
#> 2       5  909.   920.  837.  999.    13
#> 3       6  942.   930.  894. 1007.    12
#> 4       8  916.   931.  863.  955.     9
#> 5       9  946.   929.  882. 1044.    11
LeafStats2 <- LeafStats %>% 
  mutate(Mean = round(Mean), Median = round(Median), Max = round(Max), Min = round(Min),
         Lab = paste(Mean, Median, Min, Max)) %>%
  select(LeafNum, Lab)
  
FRAME <- fit$frame
FRAME$RowNum <- 1:nrow(FRAME)
FRAME <- left_join(FRAME, LeafStats2, by = c("RowNum" = "LeafNum"))
FRAME$Lab <-  ifelse(is.na(FRAME$Lab), round(FRAME$yval), FRAME$Lab)

FRAME$Lab
#> [1] "916"              "908"              "879 869 810 944" 
#> [4] "925"              "909 920 837 999"  "942 930 894 1007"
#> [7] "933"              "916 931 863 955"  "946 929 882 1044"

node.fun1 <- function(x, labs, digits, varlen) {
  FRAME$Lab
}
prp(fit, type = 2, varlen = 0, faclen = 0, box.palette = "auto", fallen.leaves = TRUE, node.fun = node.fun1)

Created on 2019-10-15 by the reprex package (v0.2.1)
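A possible extension that labels every node, not only the leaves, with mean, median, min and max. It is a sketch under two assumptions: rpart numbers the children of node k as 2k and 2k + 1, and prp calls node.fun with the tree being drawn and expects one label per row of x$frame (as the code above relies on). The labels are looked up by node number (the row names of fit$frame) rather than by row position; note that in the FRAME printout above, the left_join has replaced the node-number row names with 1 to 9. node_label and node.fun.stats are hypothetical names.

```r
library(rpart)
library(rpart.plot)

# Invented data, as above
set.seed(1)
DF1 <- data.frame(Band = round(rnorm(30, 5, 0.7), 0),
                  MedianRate = runif(30, min = 0.25, max = 2.25),
                  Value = rnorm(30, 900, 50))
DF2 <- data.frame(Band = round(rnorm(30, 6, 0.7), 0),
                  MedianRate = runif(30, min = 0.5, max = 2.5),
                  Value = rnorm(30, 940, 50))
DF <- rbind(DF1, DF2)
fit <- rpart(Value ~ Band + MedianRate, data = DF, method = "anova")

node_num  <- as.integer(rownames(fit$frame))  # node numbers, in frame order
leaf_node <- node_num[fit$where]               # leaf node of each observation

# Label for node k: statistics of all observations whose leaf lies below k.
# Assumption: halving a node number (integer division) gives its parent.
node_label <- sapply(node_num, function(k) {
  m <- leaf_node
  # Halve each leaf number until it is no larger than k
  while (any(m > k)) m[m > k] <- m[m > k] %/% 2
  v <- DF$Value[m == k]
  sprintf("mean %.0f\nmedian %.0f\nmin %.0f\nmax %.0f",
          mean(v), median(v), min(v), max(v))
})
names(node_label) <- node_num

# node.fun: one label per row of x$frame, looked up by node number
node.fun.stats <- function(x, labs, digits, varlen) {
  unname(node_label[rownames(x$frame)])
}
prp(fit, type = 2, varlen = 0, faclen = 0, box.palette = "auto",
    fallen.leaves = TRUE, node.fun = node.fun.stats)
```

Because the lookup is keyed by node number, the same node.fun.stats should also work on a pruned tree such as fit_05: pruning keeps the remaining node numbers, and a node's statistics do not depend on what lies below it.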

This topic was automatically closed 7 days after the last reply. New replies are no longer allowed.