
Machine Learning - Decision Trees

library(tidyverse)
library(rpart)
library(caret)
library(randomForest)
library(party)
cu_summary <- read_csv("cu.summary.csv")

Table of Contents

  • 1 Introduction
  • 2 A simple tree model
  • 2.1 Tree entropy and information gain
  • 2.2 Visualize a simple tree model
  • 3 Prune tree
  • 4 Decision trees for regression
  • 5 Decision trees for classification
  • 6 Conditional inference trees
  • 7 Closing word

1 Introduction

Decision trees are ordered, directed trees that represent decision rules. Their graphical representation as a tree diagram illustrates a hierarchy of successive decisions. Decision trees are an important tool in information processing: they are used in data mining to derive decision strategies as well as in machine learning. The method supports decision-making by visualizing the different possibilities and scenarios.

For this post the dataset cu.summary from the package “rpart” was used. A copy of the record is available at https://drive.google.com/open?id=1o8rinf_SUZzsfAc_K7nI6Z7uhYyy_bZi.

2 A simple tree model

Let’s create a small example dataset and fit a decision tree to it.

a <- c("cloudy", "rainy", "sunny", "cloudy", "cloudy", "rainy", "rainy", "cloudy", "sunny", "sunny", "rainy", "sunny", "sunny")
b <- c("low", "low", "high", "high", "low", "high", "high", "high", "low", "low", "low", "low", "high")
c <- c("high", "normal", "normal", "high", "normal", "high", "normal", "normal", "high", "normal", "normal", "high", "high")
d <- c("yes", "yes", "yes", "yes", "yes", "no", "no", "yes", "no", "yes", "yes", "no", "no")
df <- data.frame(a, b, c, d)
colnames(df) <- c("Sky.Condition", "Wind.Speed", "Humidity", "Result")
df
##    Sky.Condition Wind.Speed Humidity Result
## 1         cloudy        low     high    yes
## 2          rainy        low   normal    yes
## 3          sunny       high   normal    yes
## 4         cloudy       high     high    yes
## 5         cloudy        low   normal    yes
## 6          rainy       high     high     no
## 7          rainy       high   normal     no
## 8         cloudy       high   normal    yes
## 9          sunny        low     high     no
## 10         sunny        low   normal    yes
## 11         rainy        low   normal    yes
## 12         sunny        low     high     no
## 13         sunny       high     high     no

2.1 Tree entropy and information gain

Definition of entropy: Entropy is a measure of the impurity, disorder or uncertainty in a set of examples.

Entropy controls how a decision tree decides to split the data; it effectively determines where the tree draws its decision boundaries.

Definition of information gain: Information gain measures how much “information” a feature gives us about the class.

Why does it matter?

  • Information gain is the key criterion that decision tree algorithms use to construct a tree.
  • The algorithm always tries to maximize information gain.
  • The attribute with the highest information gain is tested/split first (the sketch after this list computes these values by hand for the toy data).
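
As a quick, hedged illustration of these two concepts: the entropy of a set S is -sum over the classes of p * log2(p), and the information gain of a feature is the entropy of the whole set minus the weighted average entropy of the subsets created by splitting on that feature. The helpers entropy() and info_gain() below are ad-hoc sketches written for the toy data frame above, not functions from any of the loaded packages.

entropy <- function(labels) {
  # Shannon entropy of a vector of class labels (0 * log2(0) is treated as 0)
  p <- prop.table(table(labels))
  p <- p[p > 0]
  -sum(p * log2(p))
}

info_gain <- function(data, feature, target = "Result") {
  # Entropy of the whole set minus the weighted entropy after splitting on 'feature'
  subsets <- split(data[[target]], data[[feature]])
  weights <- sapply(subsets, length) / nrow(data)
  entropy(data[[target]]) - sum(weights * sapply(subsets, entropy))
}

sapply(c("Sky.Condition", "Wind.Speed", "Humidity"), info_gain, data = df)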

Within R you can use the varImpPlot() function from the randomForest package to get a quick overview of which variables are most informative and should therefore be split on first, from the top of the tree down.

# Fit a quick random forest on the toy data and plot the variable importance
fit <- randomForest(factor(Result) ~ ., data = df)
varImpPlot(fit)
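
If you prefer the numbers behind the plot, the randomForest package also exposes the raw importance scores (MeanDecreaseGini for a classification forest) via importance():

# Raw importance scores underlying varImpPlot()
importance(fit)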

2.2 Visualize a simple tree model

# Result is categorical, so grow a fully expanded classification tree (cp = 0, no pruning)
fit2 <- rpart(factor(Result) ~ Sky.Condition + Wind.Speed + Humidity, method = "class", data = df,
              control = rpart.control(minsplit = 1, minbucket = 1, cp = 0))


plot(fit2, uniform = TRUE, margin=0.1)
text(fit2, use.n = TRUE, all=TRUE, cex=.8)

3 Prune tree

Overfitting is a general problem with decision trees. One solution is to prune the decision tree accordingly. Let’s look at the dataset we want to use for this example.

glimpse(cu_summary)
## Observations: 117
## Variables: 6
## $ X1          <chr> "Acura Integra 4", "Dodge Colt 4", "Dodge Omni 4",...
## $ Price       <int> 11950, 6851, 6995, 8895, 7402, 6319, 6695, 10125, ...
## $ Country     <chr> "Japan", "Japan", "USA", "USA", "USA", "Korea", "J...
## $ Reliability <chr> "Much better", NA, "Much worse", "better", "worse"...
## $ Mileage     <int> NA, NA, NA, 33, 33, 37, NA, NA, 32, NA, 32, 26, NA...
## $ Type        <chr> "Small", "Small", "Small", "Small", "Small", "Smal...

In this case we want to model a vehicle’s fuel efficiency, as given by the Mileage variable. Since Mileage is a numerical variable, the result is a regression tree.

fit3 <- rpart(Mileage~Price + Reliability + Type + Country, method = "anova", data = cu_summary)

plot(fit3, uniform = TRUE, margin=0.1)
text(fit3, use.n = TRUE, all=TRUE, cex=.8)

A precise way of knowing which parts to prune is to look at the tree’s complexity parameter, often referred to as the “CP”, which you can request with the plotcp() function.

plotcp(fit3)

The complexity parameter is the amount by which splitting a tree node improves the relative error. In the figure above, the first split improves the error the most (by about 0.62, as the cptable below confirms), and each additional split improves it less. You can also see from the plot that the cross-validated relative error is already below the dotted one-standard-error threshold at a tree size of 3.

You can extract all these values programmatically from the model’s cptable, as follows:

fit3$cptable
##           CP nsplit rel error    xerror       xstd
## 1 0.62288527      0 1.0000000 1.0295681 0.17828009
## 2 0.13206061      1 0.3771147 0.5311535 0.10538275
## 3 0.02544094      2 0.2450541 0.3769875 0.08448195
## 4 0.01772069      3 0.2196132 0.3727303 0.08512904
## 5 0.01000000      4 0.2018925 0.3498613 0.07613958

You can see from the xerror column that most of the reduction in cross-validated error has happened by a tree size of 3, with only marginal gains from further splits. Now let’s prune the previously created tree using the CP value with the lowest xerror.

fit3.pruned <- prune(fit3, cp = fit3$cptable[which.min(fit3$cptable[,"xerror"]), "CP"])
par(mfrow = c(1, 2))

plot(fit3, uniform = TRUE, margin=0.1, main = "Original decision tree")
text(fit3, use.n = TRUE, all=TRUE, cex=.8)

plot(fit3.pruned, uniform = TRUE, margin=0.1, main = "Pruned decision tree")
text(fit3.pruned, use.n = TRUE, all=TRUE, cex=.8)

This example takes that complexity parameter and passes it to the prune() function, effectively eliminating any splits that do not reduce the model’s cross-validated error.
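
A slightly more conservative alternative, not used in the example above, is the one-standard-error rule that the dotted line in plotcp() represents: pick the smallest tree whose xerror lies within one standard error of the minimum. A minimal sketch based only on the cptable shown above:

# 1-SE rule: smallest tree whose xerror is within one xstd of the minimum
cp_tab <- fit3$cptable
min_row <- which.min(cp_tab[, "xerror"])
threshold <- cp_tab[min_row, "xerror"] + cp_tab[min_row, "xstd"]
cp_1se <- cp_tab[which(cp_tab[, "xerror"] <= threshold)[1], "CP"]
fit3.pruned.1se <- prune(fit3, cp = cp_1se)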

4 Decision trees for regression

# Keep only complete cases so that every observation has values for all model variables
cu.summary.compl <- cu_summary[complete.cases(cu_summary), ]
nrow(cu.summary.compl)

# Random 70/30 train/test split (call set.seed() beforehand if you want it to be reproducible)
data.samples <- sample(1:nrow(cu.summary.compl), nrow(cu.summary.compl) * 0.7, replace = FALSE)

training.data1 <- cu.summary.compl[data.samples, ]
test.data1 <- cu.summary.compl[-data.samples, ]

# Fit a regression tree for Mileage and prune it at the CP with the lowest cross-validated error
fit4 <- rpart(Mileage ~ Price + Reliability + Type + Country, method = "anova", data = training.data1)

fit4.pruned <- prune(fit4, cp = fit4$cptable[which.min(fit4$cptable[, "xerror"]), "CP"])

# Predict on the hold-out data and compute the root mean squared error (RMSE) by hand
fit4.prediction <- predict(fit4.pruned, test.data1)

fit4.output <- data.frame(test.data1$Mileage, fit4.prediction)

fit4.RMSE <- sqrt(sum((fit4.output$test.data1.Mileage - fit4.output$fit4.prediction)^2)
                  / nrow(fit4.output))

fit4.RMSE
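
Since caret is already loaded, the same metric can also be computed with its RMSE() helper; this is shown only as a cross-check of the manual calculation above.

# Cross-check of the manual RMSE calculation using caret
RMSE(fit4.prediction, test.data1$Mileage)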

5 Decision trees for classification

# Reuse the complete-case data and draw a fresh 70/30 train/test split
cu.summary.compl <- cu_summary[complete.cases(cu_summary), ]

data.samples <- sample(1:nrow(cu.summary.compl), nrow(cu.summary.compl) * 0.7, replace = FALSE)

training.data2 <- cu.summary.compl[data.samples, ]
test.data2 <- cu.summary.compl[-data.samples, ]

# Fit a classification tree for the vehicle Type and prune it at the CP with the lowest cross-validated error
fit5 <- rpart(Type ~ Price + Reliability + Mileage + Country, method = "class", data = training.data2)

fit5.pruned <- prune(fit5, cp = fit5$cptable[which.min(fit5$cptable[, "xerror"]), "CP"])

# Predict class labels on the hold-out data and tabulate them against the actual types
fit5.prediction <- predict(fit5.pruned, test.data2, type = "class")

table(fit5.prediction, test.data2$Type)
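
Because caret is loaded as well, the same comparison can be summarized with a full confusion matrix, including overall accuracy and per-class statistics. The factor() wrapping below is only there to align the levels of the two vectors.

# Confusion matrix with accuracy and per-class statistics
confusionMatrix(fit5.prediction,
                factor(test.data2$Type, levels = levels(fit5.prediction)))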

6 Conditional inference trees

Conditional inference trees, provided by the party package, estimate a regression relationship by binary recursive partitioning in a conditional inference framework: splits are selected with permutation-based significance tests rather than an impurity measure such as entropy.

# Drop the first column (the car name, X1), keeping only the model variables
cu.summary.new <- cu_summary[, 2:6]

Here’s an example of a conditional inference tree used for regression:

fit6 <- ctree(Mileage ~ Price + factor(Reliability) + factor(Type) + factor(Country), data = na.omit(cu.summary.new))
plot(fit6)

And here is another example, this time a conditional inference tree used for classification:

fit7 <- ctree(factor(Type) ~ Price + factor(Reliability) + Mileage + factor(Country), data = na.omit(cu.summary.new))
plot(fit7)

7 Closing word

Advantages and disadvantages of decision trees:

Due to their structure, decision trees are easy to understand, interpret and visualize, and a form of variable checking or feature selection is performed implicitly. Both numerical and non-numerical data can be processed at the same time, and relatively little data preparation is required from the user.

On the other hand, overly complex trees can be created that do not generalize well to unseen data; this is called overfitting. Small variations in the data can also make the trees unstable, so that a completely different tree is produced; this phenomenon is called variance.

Source

Burger, S. V. (2018). Introduction to Machine Learning with R: Rigorous Mathematical Analysis. O’Reilly Media, Inc.