The ultimate guide to decision trees in R – Part 1

You want to use a predictive model and after reading articles on the internet, you decided to go for a decision tree. You are familiar with different machine learning techniques, but now it’s time to master the decision tree algorithm. If this is your situation, then this is the right article for you!

This is going to be a very long tutorial with a lot of information, background and decision tree examples. Therefore I’ll split this article into three parts. The first part focuses on the theory, terminology and available algorithms. In the second part I will work out a full example explaining all the steps and calculations. In part three I will show you how to develop different decision trees in R by using different packages, and how to interpret the results.

 

What is a decision tree?

Let’s start with the basics. So what is a decision tree actually? In one of my previous articles about the 6 most used machine learning algorithms, I explained the decision tree basics. The decision tree algorithm is a supervised learning algorithm visualized as a tree. You can use the decision tree to predict numerical values and categorical values. Because it is so versatile, it is a good fit for many business problems.

It starts with the root node, which contains all data points. The algorithm creates a decision rule by using different measures; I'll explain more about these measures further on in this article. As soon as it finds the best decision rule, it splits the node into two branches (some algorithms allow three or more). In this way the full data set is broken down into subsets, where each node contains the data points that satisfy its decision rule. This process continues until the tree ends in leaves. When you want to predict the outcome of new data, you simply follow the decision rules down the tree: the leaf where the new data point ends up determines the prediction, and for classification this is the category that appears most often in that leaf.
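As a small preview of part three, below is a minimal sketch of how this looks in R. The rpart package and the built-in iris data set are my own choices for the illustration, not part of this article's example.

# Minimal sketch (assumptions: the rpart package and the built-in iris data).
library(rpart)

# Fit a classification tree: Species is predicted from the four measurements.
fit <- rpart(Species ~ ., data = iris, method = "class")

# Print the decision rules, from the root node down to the leaves.
print(fit)

# Predict a new flower: it is routed through the decision rules and gets the
# majority category of the leaf it ends up in.
new_flower <- data.frame(Sepal.Length = 6.0, Sepal.Width = 3.0,
                         Petal.Length = 4.5, Petal.Width = 1.5)
predict(fit, newdata = new_flower, type = "class")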

 

Decision tree assumptions and best practices

Although the decision tree fits almost all business cases, there are still a couple of things to keep in mind.

  • In contrast to most other algorithms, the decision tree can handle missing values. It simply treats a missing value as a separate category. This is a big advantage, because it saves a lot of time cleaning the data. Nevertheless, I recommend looking into the missing values and understanding why the data is missing. Sometimes it's better to remove those records from the data set completely; the best way to find out is to test both cases (see the sketch after this list).
  • Another thing to be aware of is overfitting. Overfitting happens when the model bases a prediction on only a few data points. The model then performs very well on the training set, but new data is likely to be far off the prediction. Try to avoid categorical variables with many levels, since these typically cause overfitting.
  • Preferably, you want variables that are not strongly correlated with each other. The problem arises when two or more variables explain the same thing: the algorithm picks the variable with the highest predictive power and will likely not use the other variables, or only later in the tree. This means you lose information that could actually be valuable.
  • The advantage of the decision tree algorithm is that you don't have to specify the distribution of the variables. This is also a drawback, because you lose information that could make the algorithm more powerful.
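To test both cases from the first bullet, you can fit the tree twice: once with the missing values left in and once with the incomplete rows removed. The sketch below is only an illustration with the rpart package; the missing values are introduced artificially for the example.

# Minimal sketch (assumptions: rpart and a copy of iris with artificial NAs).
library(rpart)

iris_na <- iris
iris_na$Petal.Length[sample(nrow(iris_na), 15)] <- NA   # introduce missing values

# Case 1: keep the rows with missing values (rpart handles them internally).
fit_with_na <- rpart(Species ~ ., data = iris_na, method = "class")

# Case 2: remove the incomplete rows and refit.
fit_complete <- rpart(Species ~ ., data = na.omit(iris_na), method = "class")

# Compare the cross-validated errors of both trees to decide which works better.
printcp(fit_with_na)
printcp(fit_complete)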

 

Algorithms with corresponding statistics and terminology

Just like any machine learning technique, the decision tree has multiple variations. Below I'll go through the general terminology and two well-known algorithms: ID3 and CART.

General terminology

Below, I'll describe a few techniques, each with its own calculations. But some measures and terminology apply to every type of decision tree.

  • Root node: the start node containing all data points.
  • Decision node: usually just called a node; a node that is split into further nodes.
  • Leaves: nodes that do not split into new nodes, also called terminal nodes.
  • Splitting: when a root or decision node splits into new nodes, this process is called splitting.
  • Pruning: removing nodes from the tree is called pruning. We can use pre-pruning (stopping the tree early) or post-pruning (cutting back a fully grown tree) to prevent overfitting; a small sketch of both follows after this list. Some well-known pruning criteria are:
    1. Minimum node size: the minimum number of data points that should be in each leaf.
    2. Error estimation: we calculate the weighted sum of errors for each node. If a node's estimated error is smaller than the combined error of its child nodes, the children are removed from the tree.
    3. Significance tests: with a Chi-squared test we check for each split whether it is statistically significant; splits that are not significant are removed.
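Below is a minimal sketch of pre- and post-pruning with the rpart package; the package choice and the parameter values are my own illustration, not prescriptive settings.

# Minimal sketch (assumptions: rpart and the built-in iris data).
library(rpart)

# Pre-pruning: limit the tree while it is being built (minimum node size,
# maximum depth and a complexity threshold for each split).
fit_pre <- rpart(Species ~ ., data = iris, method = "class",
                 control = rpart.control(minsplit = 20, minbucket = 7,
                                         maxdepth = 4, cp = 0.01))

# Post-pruning: grow a large tree first, then cut it back at the complexity
# parameter with the lowest cross-validated error.
fit_big <- rpart(Species ~ ., data = iris, method = "class",
                 control = rpart.control(cp = 0, minsplit = 2))
best_cp <- fit_big$cptable[which.min(fit_big$cptable[, "xerror"]), "CP"]
fit_post <- prune(fit_big, cp = best_cp)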

 

ID3 algorithm

The ID3 algorithm is often the default option in software when creating a decision tree. The algorithm uses a top-down approach and chooses the best split in each iteration. The downside is that it can get stuck in a local optimum. ID3 uses entropy and information gain to find the best split per node.

  • Entropy: the entropy tells us something about the uncertainty of a variable. When all data points belong to the same category level, the entropy is zero. When the data points of a binary variable are equally divided over the two levels, the entropy is one (with more levels the maximum is higher). The entropy is calculated with the following formula:
    H(X) = \sum_{x \in X} -p(x) \log_{2} p(x)
    The formula results in a graph like this: [Figure: entropy curve for a decision tree]
  • Information gain: the variable with the highest information gain is the one used for the split in the next iteration. The information gain is the entropy of the current node minus the weighted entropy of the child nodes after the split on that variable; a small base-R sketch of both measures follows after this list.
    IG(T, a) = H(T) - H(T \mid a)
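The following is a minimal base-R sketch of both measures; the function names and the bucketed iris attribute are my own choices for illustration.

# Minimal sketch (assumptions: a categorical target y and a categorical attribute a).
entropy <- function(y) {
  p <- prop.table(table(y))
  p <- p[p > 0]                       # skip empty levels to avoid log2(0)
  -sum(p * log2(p))
}

information_gain <- function(y, a) {
  # Entropy of the parent node minus the weighted entropy of the child nodes.
  weights <- prop.table(table(a))
  entropy(y) - sum(weights * tapply(y, a, entropy))
}

# Example with the built-in iris data: how much does a bucketed petal length
# tell us about the species?
petal_bucket <- cut(iris$Petal.Length, breaks = 3)
information_gain(iris$Species, petal_bucket)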

 

CART algorithm

The CART algorithm (Classification And Regression Trees) builds a binary tree and uses the Gini index to choose the split at each node.

  • The Gini index can be seen as a cost function that we want to minimize. For a two-level target it is always a number between 0 and 0.5: the outcome is 0 if all data points belong to the same category level, and 0.5 if the data points are equally spread over the two levels.

Gini = \sum_{i \neq j} p(i)\, p(j)
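Below is a minimal base-R sketch of the Gini index; the example vectors are made up to show the two extremes. Note that rpart, used in the earlier sketches, follows the CART approach and uses the Gini index by default for classification.

# Minimal sketch (assumption: y is a categorical vector).
gini <- function(y) {
  p <- prop.table(table(y))
  1 - sum(p^2)     # identical to summing p(i) * p(j) over all pairs with i != j
}

gini(rep(c("yes", "no"), times = c(5, 5)))   # 0.5: equally spread over two levels
gini(rep("yes", 10))                         # 0.0: all points in one level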

 

In part two I will work out a full example and in part three I will show you all the code needed in R. So come back soon to check out the other two parts!

 

 
